Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/main/java/io/weaviate/client6/v1/api/Authentication.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw
};
}

/**
* Authenticate using Resource Owner Password authorization grant with client secret.
*
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
* @param scopes Client scopes.
*
* @return Authentication provider.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static Authentication resourceOwnerPassword(String clientSecret, String username, String password,
List<String> scopes) {
return transport -> {
OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access");
return TokenProvider.resourceOwnerPassword(oidc, clientSecret, username, password);
};
}

/**
* Authenticate using Client Credentials authorization grant.
*
Expand Down
21 changes: 17 additions & 4 deletions src/main/java/io/weaviate/client6/v1/api/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import io.weaviate.client6.v1.internal.BuildInfo;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions;
Expand All @@ -23,7 +24,8 @@ public record Config(
Map<String, String> headers,
Authentication authentication,
TrustManagerFactory trustManagerFactory,
Timeout timeout) {
Timeout timeout,
Proxy proxy) {

public static Config of(Function<Custom, ObjectBuilder<Config>> fn) {
return fn.apply(new Custom()).build();
Expand All @@ -39,23 +41,24 @@ private Config(Builder<?> builder) {
builder.headers,
builder.authentication,
builder.trustManagerFactory,
builder.timeout);
builder.timeout,
builder.proxy);
}

RestTransportOptions restTransportOptions() {
return restTransportOptions(null);
}

RestTransportOptions restTransportOptions(TokenProvider tokenProvider) {
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout);
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

GrpcChannelOptions grpcTransportOptions() {
return grpcTransportOptions(null);
}

GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) {
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout);
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

private abstract static class Builder<SelfT extends Builder<SelfT>> implements ObjectBuilder<Config> {
Expand All @@ -69,6 +72,7 @@ private abstract static class Builder<SelfT extends Builder<SelfT>> implements O
protected TrustManagerFactory trustManagerFactory;
protected Timeout timeout = new Timeout();
protected Map<String, String> headers = new HashMap<>();
protected Proxy proxy;

/**
* Set URL scheme. Subclasses may increase the visibility of this method to
Expand Down Expand Up @@ -174,6 +178,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) {
return (SelfT) this;
}

/**
* Set proxy for all requests.
*/
@SuppressWarnings("unchecked")
public SelfT proxy(Proxy proxy) {
this.proxy = proxy;
return (SelfT) this;
}

/**
* Weaviate will use the URL in this header to call Weaviate Embeddings
* Service if an appropriate vectorizer is configured for collection.
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
import io.weaviate.client6.v1.internal.rest.RestTransport;
import io.weaviate.client6.v1.internal.rest.RestTransportOptions;
import lombok.Getter;

public class WeaviateClient implements AutoCloseable {
/** Store this for {@link #async()} helper. */
@Getter
private final Config config;

private final RestTransport restTransport;
Expand Down Expand Up @@ -63,14 +65,13 @@ public class WeaviateClient implements AutoCloseable {
public final WeaviateClusterClient cluster;

public WeaviateClient(Config config) {
RestTransportOptions restOpt;
RestTransportOptions restOpt = config.restTransportOptions();
GrpcChannelOptions grpcOpt;
if (config.authentication() == null) {
restOpt = config.restTransportOptions();
grpcOpt = config.grpcTransportOptions();
} else {
TokenProvider tokenProvider;
try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) {
try (final var noAuthRest = new DefaultRestTransport(restOpt)) {
tokenProvider = config.authentication().getTokenProvider(noAuthRest);
} catch (Exception e) {
// Generally exceptions are caught in TokenProvider internals.
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/Proxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.weaviate.client6.v1.internal;

import javax.annotation.Nullable;

public class Proxy {
private final String host;
private final int port;
private final String scheme;
private final String username;
private final String password;

public Proxy(String host, int port, String scheme, @Nullable String username, @Nullable String password) {
this.host = host;
this.port = port;
this.scheme = scheme;
this.username = username;
this.password = password;
}

public Proxy(String host, int port) {
this(host, port, "http", null, null);
}

public String host() {
return host;
}

public int port() {
return port;
}

public String scheme() {
return scheme;
}

@Nullable
public String username() {
return username;
}

@Nullable
public String password() {
return password;
}
}
19 changes: 19 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,25 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Resource Owner Password authorization grant
* with client secret.
*
* @param oidc OIDC config.
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
*
* @return Internal TokenProvider implementation.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String clientSecret, String username,
String password) {
final var passwordGrant = NimbusTokenProvider.resourceOwnerPassword(oidc, clientSecret, username, password);
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Client Credentials authorization grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ public abstract class TransportOptions<H> {
protected final H headers;
protected final TrustManagerFactory trustManagerFactory;
protected final Timeout timeout;
protected final Proxy proxy;

protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider,
TrustManagerFactory tmf, Timeout timeout) {
TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this.scheme = scheme;
this.host = host;
this.port = port;
this.tokenProvider = tokenProvider;
this.headers = headers;
this.timeout = timeout;
this.trustManagerFactory = tmf;
this.proxy = proxy;
}

public boolean isSecure() {
Expand Down Expand Up @@ -57,4 +59,9 @@ public H headers() {
public TrustManagerFactory trustManagerFactory() {
return this.trustManagerFactory;
}

@Nullable
public Proxy proxy() {
return this.proxy;
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
package io.weaviate.client6.v1.internal.grpc;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLException;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import io.grpc.HttpConnectProxiedSocketAddress;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
Expand All @@ -17,10 +12,17 @@
import io.grpc.stub.AbstractStub;
import io.grpc.stub.MetadataUtils;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public final class DefaultGrpcTransport implements GrpcTransport {
private final ManagedChannel channel;

Expand Down Expand Up @@ -88,7 +90,7 @@ public <RequestT, RequestM, ReplyM, ResponseT> CompletableFuture<ResponseT> perf
var method = rpc.methodAsync();
var stub = applyTimeout(futureStub, rpc);
var reply = method.apply(stub, message);
return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r));
return toCompletableFuture(reply).thenApply(rpc::unmarshal);
}

/**
Expand Down Expand Up @@ -139,6 +141,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions)
channel.sslContext(sslCtx);
}

if (transportOptions.proxy() != null) {
Proxy proxy = transportOptions.proxy();
if ("http".equals(proxy.scheme())) {
final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port());
channel.proxyDetector(targetAddress -> {
if (targetAddress instanceof InetSocketAddress) {
HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder()
.setProxyAddress(proxyAddress)
.setTargetAddress((InetSocketAddress) targetAddress);

if (proxy.username() != null && proxy.password() != null) {
builder.setUsername(proxy.username());
builder.setPassword(proxy.password());
}
return builder.build();
}
return null;
});
}
}

channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));

return channel.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import javax.net.ssl.TrustManagerFactory;

import io.grpc.Metadata;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TransportOptions;
Expand All @@ -13,19 +14,19 @@ public class GrpcChannelOptions extends TransportOptions<Metadata> {
private final Integer maxMessageSize;

public GrpcChannelOptions(String scheme, String host, int port, Map<String, String> headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout, proxy);
}

private GrpcChannelOptions(String scheme, String host, int port, Metadata headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, Integer maxMessageSize, Timeout timeout, Proxy proxy) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout, proxy);
this.maxMessageSize = maxMessageSize;
}

public GrpcChannelOptions withMaxMessageSize(int maxMessageSize) {
return new GrpcChannelOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, maxMessageSize,
timeout);
timeout, proxy);
}

public Integer maxMessageSize() {
Expand Down
24 changes: 20 additions & 4 deletions src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,32 @@
public record OidcConfig(
String clientId,
String providerMetadata,
Set<String> scopes) {
Set<String> scopes,
OidcProxy proxy) {

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
public record OidcProxy(
String host,
int port,
String scheme) {
}

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes, OidcProxy proxy) {
this.clientId = clientId;
this.providerMetadata = providerMetadata;
this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet();
this.proxy = proxy;
}

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
this(clientId, providerMetadata, scopes, null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes));
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes, OidcProxy proxy) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), proxy);
}

/** Create a new OIDC config with extended scopes. */
Expand All @@ -31,6 +47,6 @@ public OidcConfig withScopes(String... scopes) {
/** Create a new OIDC config with extended scopes. */
public OidcConfig withScopes(List<String> scopes) {
var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet());
return new OidcConfig(clientId, providerMetadata, newScopes);
return new OidcConfig(clientId, providerMetadata, newScopes, proxy);
}
}
Loading
Loading