diff --git a/amoro-common/src/main/java/org/apache/amoro/client/AmsThriftUrl.java b/amoro-common/src/main/java/org/apache/amoro/client/AmsThriftUrl.java index 860d5c6dca..d5fe655202 100644 --- a/amoro-common/src/main/java/org/apache/amoro/client/AmsThriftUrl.java +++ b/amoro-common/src/main/java/org/apache/amoro/client/AmsThriftUrl.java @@ -45,7 +45,8 @@ public class AmsThriftUrl { public static final String THRIFT_URL_FORMAT = "thrift://%s:%d/%s%s"; public static final int MAX_RETRIES = 3; private static final Logger logger = LoggerFactory.getLogger(AmsThriftUrl.class); - private static final Pattern PATTERN = Pattern.compile("zookeeper://(\\S+)/([\\w-]+)"); + private static final Pattern PATTERN = + Pattern.compile("zookeeper://(\\S+)/([\\w-]+)", Pattern.CASE_INSENSITIVE); private final String schema; private final String host; private final int port; @@ -77,34 +78,28 @@ public static AmsThriftUrl parse(String url, String serviceName) { if (url == null) { throw new IllegalArgumentException("thrift url is null"); } - if (url.startsWith(ZOOKEEPER_FLAG)) { + String scheme = parseScheme(url); + if (ZOOKEEPER_FLAG.equalsIgnoreCase(scheme)) { return parserZookeeperUrl(url, serviceName); - } else { + } else if ("thrift".equalsIgnoreCase(scheme)) { return parserThriftUrl(url); } + throw new IllegalArgumentException( + String.format("Unsupported AMS URL scheme '%s' for url %s", scheme, url)); } private static AmsThriftUrl parserThriftUrl(String url) { - int socketTimeout = DEFAULT_SOCKET_TIMEOUT; try { - URI uri = new URI(url.toLowerCase(Locale.ROOT)); - String schema = uri.getScheme(); + URI uri = new URI(url); + String schema = validateScheme(uri, "thrift", url); String host = uri.getHost(); int port = uri.getPort(); + validateHostAndPort(host, port, url); String path = uri.getPath(); if (path != null && path.startsWith("/")) { path = path.substring(1); } - if (uri.getQuery() != null) { - for (String paramExpression : uri.getQuery().split("&")) { - String[] paramSplit = paramExpression.split("="); - if (paramSplit.length == 2) { - if (paramSplit[0].equalsIgnoreCase(PARAM_SOCKET_TIMEOUT)) { - socketTimeout = Integer.parseInt(paramSplit[1]); - } - } - } - } + int socketTimeout = parseSocketTimeout(uri.getQuery(), url); String catalogName = path; return new AmsThriftUrl(schema, host, port, catalogName, socketTimeout, url); } catch (URISyntaxException e) { @@ -143,20 +138,12 @@ private static AmsThriftUrl parserZookeeperUrl(String url, String serviceName) { serverInfo.getThriftBindPort(), catalog, query); - int socketTimeout = DEFAULT_SOCKET_TIMEOUT; - for (String paramExpression : query.replace("?", "").split("&")) { - String[] paramSplit = paramExpression.split("="); - if (paramSplit.length == 2) { - if (paramSplit[0].equalsIgnoreCase(PARAM_SOCKET_TIMEOUT)) { - socketTimeout = Integer.parseInt(paramSplit[1]); - } - } - } + int socketTimeout = parseSocketTimeout(query.replace("?", ""), url); return new AmsThriftUrl( "thrift", serverInfo.getHost(), serverInfo.getThriftBindPort(), - catalog.toLowerCase(), + catalog, socketTimeout, url); } catch (KeeperException.AuthFailedException authFailedException) { @@ -213,14 +200,15 @@ public static List parseMasterSlaveAmsNodes(String url) { if (url == null) { throw new IllegalArgumentException("thrift url is null"); } + String scheme = parseScheme(url); + if (!ZOOKEEPER_FLAG.equalsIgnoreCase(scheme)) { + throw new IllegalArgumentException( + "parseMasterSlaveAmsNodes only supports ZooKeeper URL format: zookeeper://host:port/cluster"); + } return parserZookeeperUrlListForMasterSlaveMode(url); } private static List parserZookeeperUrlListForMasterSlaveMode(String url) { - if (!url.startsWith(ZOOKEEPER_FLAG)) { - throw new IllegalArgumentException( - "parseMasterSlaveAmsNodes only supports ZooKeeper URL format: zookeeper://host:port/cluster"); - } String thriftUrl = url; if (url.contains("?")) { thriftUrl = url.substring(0, url.indexOf("?")); @@ -298,6 +286,59 @@ private static List parserZookeeperUrlListForMasterSlaveMode(Stri return serverInfoList; } + private static String parseScheme(String url) { + try { + URI uri = new URI(url); + String scheme = uri.getScheme(); + if (scheme == null || scheme.trim().isEmpty()) { + throw new IllegalArgumentException("AMS URL scheme is required for url " + url); + } + return scheme; + } catch (URISyntaxException e) { + throw new IllegalArgumentException("parse metastore url failed", e); + } + } + + private static String validateScheme(URI uri, String expectedScheme, String url) { + String scheme = uri.getScheme(); + if (scheme == null || !expectedScheme.equalsIgnoreCase(scheme)) { + throw new IllegalArgumentException( + String.format("Unsupported AMS URL scheme '%s' for url %s", scheme, url)); + } + return scheme.toLowerCase(Locale.ROOT); + } + + private static void validateHostAndPort(String host, int port, String url) { + if (host == null || host.trim().isEmpty()) { + throw new IllegalArgumentException("AMS thrift url host is required: " + url); + } + if (port < 0) { + throw new IllegalArgumentException("AMS thrift url port is required: " + url); + } + } + + private static int parseSocketTimeout(String query, String url) { + int socketTimeout = DEFAULT_SOCKET_TIMEOUT; + if (query == null || query.isEmpty()) { + return socketTimeout; + } + for (String paramExpression : query.split("&")) { + if (paramExpression.isEmpty()) { + continue; + } + String[] paramSplit = paramExpression.split("=", 2); + if (paramSplit.length == 2 && paramSplit[0].equalsIgnoreCase(PARAM_SOCKET_TIMEOUT)) { + try { + socketTimeout = Integer.parseInt(paramSplit[1]); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format("Invalid socketTimeout '%s' in AMS url %s", paramSplit[1], url), e); + } + } + } + return socketTimeout; + } + public String schema() { return schema; } diff --git a/amoro-common/src/test/java/org/apache/amoro/client/TestAmsThriftUrl.java b/amoro-common/src/test/java/org/apache/amoro/client/TestAmsThriftUrl.java new file mode 100644 index 0000000000..efcc9f000f --- /dev/null +++ b/amoro-common/src/test/java/org/apache/amoro/client/TestAmsThriftUrl.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.amoro.client; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class TestAmsThriftUrl { + + @Test + public void testParseThriftUrlPreservesCatalogCase() { + AmsThriftUrl thriftUrl = + AmsThriftUrl.parse("ThRiFt://LOCALHOST:1260/MyCatalog?socketTimeout=6000", null); + + Assertions.assertEquals("thrift", thriftUrl.schema()); + Assertions.assertEquals("LOCALHOST", thriftUrl.host()); + Assertions.assertEquals(1260, thriftUrl.port()); + Assertions.assertEquals("MyCatalog", thriftUrl.catalogName()); + Assertions.assertEquals(6000, thriftUrl.socketTimeout()); + Assertions.assertEquals( + "ThRiFt://LOCALHOST:1260/MyCatalog?socketTimeout=6000", thriftUrl.url()); + } + + @Test + public void testParseThriftUrlSocketTimeoutParameterIsCaseInsensitive() { + AmsThriftUrl thriftUrl = + AmsThriftUrl.parse("thrift://127.0.0.1:1260/catalog?SocketTimeout=7000", null); + + Assertions.assertEquals(7000, thriftUrl.socketTimeout()); + } + + @Test + public void testParseThriftUrlRejectsUnsupportedScheme() { + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> AmsThriftUrl.parse("http://127.0.0.1:1260/catalog", null)); + + Assertions.assertTrue(exception.getMessage().contains("Unsupported AMS URL scheme")); + } + + @Test + public void testParseThriftUrlRejectsMissingHost() { + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> AmsThriftUrl.parse("thrift:///catalog", null)); + + Assertions.assertTrue(exception.getMessage().contains("host is required")); + } + + @Test + public void testParseThriftUrlRejectsMissingPort() { + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> AmsThriftUrl.parse("thrift://127.0.0.1/catalog", null)); + + Assertions.assertTrue(exception.getMessage().contains("port is required")); + } +}