Skip to content

Commit

Permalink
Add proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
wsargent authored and mkurz committed Jun 27, 2024
1 parent e58065a commit 7269da9
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright (C) Lightbend Inc. <https://www.lightbend.com>
*/

package play.libs.ws.ahc;

import play.libs.ws.WSProxyServer;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class DefaultWSProxyServer implements WSProxyServer {
private final String host;
private final int port;
private final String protocol;
private final String proxyType;
private final String principal;
private final String password;
private final String ntlmDomain;
private final List<String> nonProxyHosts;
private final String encoding;

DefaultWSProxyServer(String host,
Integer port,
String protocol,
String proxyType,
String principal,
String password,
String ntlmDomain,
List<String> nonProxyHosts,
String encoding) {
this.host = Objects.requireNonNull(host, "host cannot be null!");
this.port = Objects.requireNonNull(port, "port cannot be null");
this.protocol = protocol;
this.proxyType = proxyType;
this.principal = principal;
this.password = password;
this.ntlmDomain = ntlmDomain;
this.nonProxyHosts = nonProxyHosts;
this.encoding = encoding;
}

@Override
public String getHost() {
return this.host;
}

@Override
public int getPort() {
return this.port;
}

@Override
public Optional<String> getProtocol() {
return Optional.ofNullable(this.protocol);
}

@Override
public Optional<String> getProxyType() {
return Optional.ofNullable(this.proxyType);
}

@Override
public Optional<String> getPrincipal() {
return Optional.ofNullable(this.principal);
}

@Override
public Optional<String> getPassword() {
return Optional.ofNullable(this.password);
}

@Override
public Optional<String> getNtlmDomain() {
return Optional.ofNullable(this.ntlmDomain);
}

@Override
public Optional<String> getEncoding() {
return Optional.ofNullable(this.encoding);
}

@Override
public Optional<List<String>> getNonProxyHosts() {
return Optional.ofNullable(this.nonProxyHosts);
}

static Builder builder() {
return new Builder();
}

static class Builder {
private String host;
private Integer port;
private String protocol;
private String proxyType;
private String principal;
private String password;
private String ntlmDomain;
private List<String> nonProxyHosts;
private String encoding;

public Builder withHost(String host) {
this.host = host;
return this;
}

public Builder withPort(int port) {
this.port = port;
return this;
}

public Builder withProtocol(String protocol) {
this.protocol = protocol;
return this;
}

public Builder withProxyType(String proxyType) {
this.proxyType = proxyType;
return this;
}

public Builder withPrincipal(String principal) {
this.principal = principal;
return this;
}

public Builder withPassword(String password) {
this.password = password;
return this;
}

public Builder withNtlmDomain(String ntlmDomain) {
this.ntlmDomain = ntlmDomain;
return this;
}

public Builder withNonProxyHosts(List<String> nonProxyHosts) {
this.nonProxyHosts = nonProxyHosts;
return this;
}

public Builder withEncoding(String encoding) {
this.encoding = encoding;
return this;
}

public WSProxyServer build() {
return new DefaultWSProxyServer(host,
port,
protocol,
proxyType,
principal,
password,
ntlmDomain,
nonProxyHosts,
encoding);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,17 @@
import play.shaded.ahc.org.asynchttpclient.RequestBuilder;
import play.shaded.ahc.org.asynchttpclient.SignatureCalculator;

import play.shaded.ahc.org.asynchttpclient.proxy.ProxyServer;
import play.shaded.ahc.org.asynchttpclient.proxy.ProxyType;
import play.shaded.ahc.org.asynchttpclient.util.HttpUtils;

import java.net.MalformedURLException;
import java.net.Proxy;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.CompletionStage;

import static java.util.Collections.singletonList;
Expand All @@ -67,6 +62,8 @@ public class StandaloneAhcWSRequest implements StandaloneWSRequest {

private WSAuthInfo auth;
private WSSignatureCalculator calculator;
private WSProxyServer proxyServer;

private final StandaloneAhcWSClient client;

private final Materializer materializer;
Expand Down Expand Up @@ -243,6 +240,15 @@ public StandaloneAhcWSRequest setContentType(String contentType) {
return addHeader(CONTENT_TYPE.toString(), contentType);
}

@Override
public StandaloneWSRequest setProxyServer(WSProxyServer proxyServer) {
if (proxyServer == null) {
throw new IllegalArgumentException("proxyServer must not be null.");
}
this.proxyServer = proxyServer;
return this;
}

@Override
public Optional<String> getContentType() {
return getHeader(CONTENT_TYPE.toString());
Expand Down Expand Up @@ -324,6 +330,11 @@ public Optional<WSAuthInfo> getAuth() {
return Optional.ofNullable(this.auth);
}

@Override
public Optional<WSProxyServer> getProxyServer() {
return Optional.ofNullable(this.proxyServer);
}

@Override
public Optional<WSSignatureCalculator> getCalculator() {
return Optional.ofNullable(this.calculator);
Expand Down Expand Up @@ -527,9 +538,52 @@ Request buildRequest() {
builder.addCookie(ahcCookie);
});

getProxyServer().ifPresent(ps -> builder.setProxyServer(createProxy(ps)));

return builder.build();
}

private ProxyServer createProxy(WSProxyServer proxyServer) {
String host = proxyServer.getHost();
int port = proxyServer.getPort();
ProxyServer.Builder proxyBuilder = new ProxyServer.Builder(host, port);

proxyServer.getPrincipal().ifPresent(principal -> {
Realm.Builder realmBuilder = new Realm.Builder(principal, proxyServer.getPassword().orElse(null));
String protocol = proxyServer.getProtocol().orElse("http").toLowerCase(Locale.ENGLISH);
switch (protocol) {
case "http":
case "https":
realmBuilder.setScheme(Realm.AuthScheme.BASIC);
case "kerberos":
realmBuilder.setScheme(Realm.AuthScheme.KERBEROS);
case "ntlm":
realmBuilder.setScheme(Realm.AuthScheme.NTLM);
case "spnego":
realmBuilder.setScheme(Realm.AuthScheme.SPNEGO);
default:
// Default to BASIC rather than throwing an error.
realmBuilder.setScheme(Realm.AuthScheme.BASIC);
}
proxyServer.getEncoding().ifPresent(enc -> realmBuilder.setCharset(Charset.forName(enc)));
proxyServer.getNtlmDomain().ifPresent(realmBuilder::setNtlmDomain);
proxyBuilder.setRealm(realmBuilder);
});

String proxyType = proxyServer.getProxyType().orElse("http").toLowerCase(Locale.ENGLISH);
switch (proxyType) {
case "http":
proxyBuilder.setProxyType(ProxyType.HTTP);
case "socksv4":
proxyBuilder.setProxyType(ProxyType.SOCKS_V4);
case "socksv5":
proxyBuilder.setProxyType(ProxyType.SOCKS_V5);
}

proxyServer.getNonProxyHosts().ifPresent(proxyBuilder::setNonProxyHosts);
return proxyBuilder.build();
}

private static void addValueTo(Map<String, List<String>> map, String name, String value) {
final Optional<String> existing = map.keySet().stream().filter(s -> s.equalsIgnoreCase(name)).findAny();
if (existing.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import play.shaded.ahc.io.netty.buffer.Unpooled
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
import play.shaded.ahc.org.asynchttpclient._
import play.shaded.ahc.org.asynchttpclient.proxy.ProxyType
import play.shaded.ahc.org.asynchttpclient.proxy.{ ProxyServer => AHCProxyServer }
import play.shaded.ahc.org.asynchttpclient.util.HttpUtils

Expand Down Expand Up @@ -450,6 +451,16 @@ case class StandaloneAhcWSRequest(
proxyBuilder.setRealm(realmBuilder)
}

val proxyType = wsProxyServer.proxyType.getOrElse("http").toLowerCase(java.util.Locale.ENGLISH) match {
case "http" =>
ProxyType.HTTP
case "socksv4" =>
ProxyType.SOCKS_V4
case "socksv5" =>
ProxyType.SOCKS_V5
}
proxyBuilder.setProxyType(proxyType);

wsProxyServer.nonProxyHosts.foreach { nonProxyHosts =>
import scala.jdk.CollectionConverters._
proxyBuilder.setNonProxyHosts(nonProxyHosts.asJava)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import org.mockito.Mockito
import org.specs2.execute.Result
import org.specs2.mutable.Specification
import org.specs2.specification.AfterAll

import play.api.libs.oauth.ConsumerKey
import play.api.libs.oauth.RequestToken
import play.api.libs.oauth.OAuthCalculator
Expand All @@ -24,6 +23,7 @@ import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaderNames
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
import play.shaded.ahc.org.asynchttpclient.SignatureCalculator
import play.shaded.ahc.org.asynchttpclient.Param
import play.shaded.ahc.org.asynchttpclient.proxy.ProxyType
import play.shaded.ahc.org.asynchttpclient.{ Request => AHCRequest }

import scala.reflect.ClassTag
Expand Down Expand Up @@ -457,8 +457,10 @@ class AhcWSRequestSpec extends Specification with AfterAll with DefaultBodyReada
protocol = Some("https"),
host = "localhost",
port = 8080,
proxyType = Some("socksv5"),
principal = Some("principal"),
password = Some("password")
password = Some("password"),
nonProxyHosts = Some(List("derp"))
)
val req: AHCRequest = client
.url("http://playframework.com/")
Expand All @@ -472,6 +474,8 @@ class AhcWSRequestSpec extends Specification with AfterAll with DefaultBodyReada
(actual.getRealm.getPrincipal must be).equalTo("principal")
(actual.getRealm.getPassword must be).equalTo("password")
(actual.getRealm.getScheme must be).equalTo(AuthScheme.BASIC)
(actual.getProxyType must be).equalTo(ProxyType.SOCKS_V5)
(actual.getNonProxyHosts.asScala must contain("derp"))
}

"support a proxy server with NTLM" in withClient { client =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import org.specs2.mutable._
import play.libs.oauth.OAuth
import play.libs.ws._
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaderNames
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
import play.shaded.ahc.org.asynchttpclient.Request
import play.shaded.ahc.org.asynchttpclient.RequestBuilderBase
import play.shaded.ahc.org.asynchttpclient.SignatureCalculator
import play.shaded.ahc.org.asynchttpclient.proxy.ProxyType

import scala.jdk.CollectionConverters._
import scala.collection.mutable
Expand Down Expand Up @@ -175,6 +177,36 @@ class AhcWSRequestSpec extends Specification with DefaultBodyReadables with Defa

}

"Use a proxy server" in {
val client = StandaloneAhcWSClient.create(
AhcWSClientConfigFactory.forConfig(ConfigFactory.load(), this.getClass.getClassLoader), /*materializer*/ null
)
val request = new StandaloneAhcWSRequest(client, "http://example.com", /*materializer*/ null)
val proxyServer = DefaultWSProxyServer
.builder()
.withHost("localhost")
.withPort(8080)
.withPrincipal("principal")
.withPassword("password")
.withProxyType("socksv5")
.withNonProxyHosts(java.util.Arrays.asList("derp"))
.build()

val req = request
.setProxyServer(proxyServer)
.asInstanceOf[StandaloneAhcWSRequest]
.buildRequest()
val actual = req.getProxyServer

(actual.getHost must be).equalTo("localhost")
(actual.getPort must be).equalTo(8080)
(actual.getRealm.getPrincipal must be).equalTo("principal")
(actual.getRealm.getPassword must be).equalTo("password")
(actual.getRealm.getScheme must be).equalTo(AuthScheme.BASIC)
(actual.getProxyType must be).equalTo(ProxyType.SOCKS_V5)
(actual.getNonProxyHosts.asScala must contain("derp"))
}

"Use a custom signature calculator" in {
val client = StandaloneAhcWSClient.create(
AhcWSClientConfigFactory.forConfig(ConfigFactory.load(), this.getClass.getClassLoader), /*materializer*/ null
Expand Down
Loading

0 comments on commit 7269da9

Please sign in to comment.