Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageResponseFormat;

/**
* The configuration information for a image generation request.
Expand Down Expand Up @@ -81,7 +82,7 @@ public class AzureOpenAiImageOptions implements ImageOptions {
* b64_json.
*/
@JsonProperty("response_format")
private String responseFormat;
private ImageResponseFormat responseFormat;

/**
* The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for
Expand Down Expand Up @@ -150,13 +151,21 @@ public void setHeight(Integer height) {

@Override
public String getResponseFormat() {
return (this.responseFormat != null) ? this.responseFormat.getValue() : null;
}

public ImageResponseFormat getResponseFormatEnum() {
return this.responseFormat;
}

public void setResponseFormat(String responseFormat) {
public void setResponseFormat(ImageResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public void setResponseFormat(String responseFormat) {
this.responseFormat = ImageResponseFormat.fromValue(responseFormat);
}

public String getSize() {
if (this.size != null) {
return this.size;
Expand Down Expand Up @@ -279,6 +288,11 @@ public Builder deploymentName(String deploymentName) {
return this;
}

public Builder responseFormat(ImageResponseFormat responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
}

public Builder responseFormat(String responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageResponseFormat;

/**
* OpenAI Image API options. OpenAiImageOptions.java
Expand Down Expand Up @@ -79,7 +80,7 @@ public class OpenAiImageOptions implements ImageOptions {
* b64_json.
*/
@JsonProperty("response_format")
private String responseFormat;
private ImageResponseFormat responseFormat;

/**
* The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for
Expand Down Expand Up @@ -159,13 +160,21 @@ public void setQuality(String quality) {

@Override
public String getResponseFormat() {
return (this.responseFormat != null) ? this.responseFormat.getValue() : null;
}

public ImageResponseFormat getResponseFormatEnum() {
return this.responseFormat;
}

public void setResponseFormat(String responseFormat) {
public void setResponseFormat(ImageResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public void setResponseFormat(String responseFormat) {
this.responseFormat = ImageResponseFormat.fromValue(responseFormat);
}

@Override
public Integer getWidth() {
if (this.width != null) {
Expand Down Expand Up @@ -326,6 +335,11 @@ public Builder quality(String quality) {
return this;
}

public Builder responseFormat(ImageResponseFormat responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
}

public Builder responseFormat(String responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.image.ImageResponseFormat;

import static org.assertj.core.api.Assertions.assertThat;

/**
Expand All @@ -34,7 +36,7 @@ void testBuilderWithAllFields() {
.N(2)
.model("dall-e-3")
.quality("hd")
.responseFormat("url")
.responseFormat(ImageResponseFormat.URL)
.width(1024)
.height(1024)
.style("vivid")
Expand All @@ -45,6 +47,7 @@ void testBuilderWithAllFields() {
assertThat(options.getModel()).isEqualTo("dall-e-3");
assertThat(options.getQuality()).isEqualTo("hd");
assertThat(options.getResponseFormat()).isEqualTo("url");
assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL);
assertThat(options.getWidth()).isEqualTo(1024);
assertThat(options.getHeight()).isEqualTo(1024);
assertThat(options.getSize()).isEqualTo("1024x1024");
Expand All @@ -58,7 +61,7 @@ void testCopy() {
.N(3)
.model("dall-e-3")
.quality("standard")
.responseFormat("b64_json")
.responseFormat(ImageResponseFormat.B64_JSON)
.width(1792)
.height(1024)
.style("natural")
Expand All @@ -72,6 +75,7 @@ void testCopy() {
assertThat(copied.getModel()).isEqualTo(original.getModel());
assertThat(copied.getQuality()).isEqualTo(original.getQuality());
assertThat(copied.getResponseFormat()).isEqualTo(original.getResponseFormat());
assertThat(copied.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum());
assertThat(copied.getWidth()).isEqualTo(original.getWidth());
assertThat(copied.getHeight()).isEqualTo(original.getHeight());
assertThat(copied.getSize()).isEqualTo(original.getSize());
Expand All @@ -85,6 +89,7 @@ void testCopy() {
assertThat(copiedViaMethod.getModel()).isEqualTo(original.getModel());
assertThat(copiedViaMethod.getQuality()).isEqualTo(original.getQuality());
assertThat(copiedViaMethod.getResponseFormat()).isEqualTo(original.getResponseFormat());
assertThat(copiedViaMethod.getResponseFormatAsEnum()).isEqualTo(original.getResponseFormatAsEnum());
assertThat(copiedViaMethod.getWidth()).isEqualTo(original.getWidth());
assertThat(copiedViaMethod.getHeight()).isEqualTo(original.getHeight());
assertThat(copiedViaMethod.getSize()).isEqualTo(original.getSize());
Expand All @@ -99,7 +104,7 @@ void testSetters() {
options.setN(4);
options.setModel("dall-e-2");
options.setQuality("standard");
options.setResponseFormat("url");
options.setResponseFormat(ImageResponseFormat.URL);
options.setWidth(512);
options.setHeight(512);
options.setStyle("vivid");
Expand All @@ -109,6 +114,7 @@ void testSetters() {
assertThat(options.getModel()).isEqualTo("dall-e-2");
assertThat(options.getQuality()).isEqualTo("standard");
assertThat(options.getResponseFormat()).isEqualTo("url");
assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL);
assertThat(options.getWidth()).isEqualTo(512);
assertThat(options.getHeight()).isEqualTo(512);
assertThat(options.getSize()).isEqualTo("512x512");
Expand Down Expand Up @@ -212,7 +218,7 @@ void testFluentApiPattern() {
.N(1)
.model("dall-e-3")
.quality("hd")
.responseFormat("url")
.responseFormat(ImageResponseFormat.URL)
.width(1024)
.height(1024)
.style("vivid")
Expand All @@ -223,6 +229,7 @@ void testFluentApiPattern() {
assertThat(options.getModel()).isEqualTo("dall-e-3");
assertThat(options.getQuality()).isEqualTo("hd");
assertThat(options.getResponseFormat()).isEqualTo("url");
assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.URL);
assertThat(options.getWidth()).isEqualTo(1024);
assertThat(options.getHeight()).isEqualTo(1024);
assertThat(options.getSize()).isEqualTo("1024x1024");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseFormat;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.observation.conventions.AiOperationType;
Expand Down Expand Up @@ -61,7 +62,7 @@ void observationForImageOperation() {
.model(OpenAiImageApi.ImageModel.DALL_E_3.getValue())
.height(1024)
.width(1024)
.responseFormat("url")
.responseFormat(ImageResponseFormat.URL)
.style("natural")
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageResponseFormat;
import org.springframework.ai.stabilityai.StyleEnum;

/**
Expand Down Expand Up @@ -122,7 +123,7 @@ public class StabilityAiImageOptions implements ImageOptions {
* accept header. Must be "application/json" or "image/png"
*/
@JsonProperty("response_format")
private String responseFormat;
private ImageResponseFormat responseFormat;

/**
* The strictness level of the diffusion process adherence to the prompt text.
Expand Down Expand Up @@ -329,13 +330,21 @@ public void setHeight(Integer height) {

@Override
public String getResponseFormat() {
return (this.responseFormat != null) ? this.responseFormat.getValue() : null;
}

public ImageResponseFormat getResponseFormatEnum() {
return this.responseFormat;
}

public void setResponseFormat(String responseFormat) {
public void setResponseFormat(ImageResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public void setResponseFormat(String responseFormat) {
this.responseFormat = ImageResponseFormat.fromValue(responseFormat);
}

public Float getCfgScale() {
return this.cfgScale;
}
Expand Down Expand Up @@ -455,6 +464,11 @@ public Builder height(Integer height) {
return this;
}

public Builder responseFormat(ImageResponseFormat responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
}

public Builder responseFormat(String responseFormat) {
this.options.setResponseFormat(responseFormat);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageResponseFormat;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;

Expand All @@ -37,7 +38,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() {
.model("default-model")
.width(512)
.height(512)
.responseFormat("image/png")
.responseFormat(ImageResponseFormat.IMAGE_PNG)
.cfgScale(7.0f)
.clipGuidancePreset("FAST_BLUE")
.sampler("DDIM")
Expand All @@ -52,7 +53,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() {
.model("runtime-model")
.width(1024)
.height(768)
.responseFormat("application/json")
.responseFormat(ImageResponseFormat.APPLICATION_JSON)
.cfgScale(14.0f)
.clipGuidancePreset("FAST_GREEN")
.sampler("DDPM")
Expand All @@ -72,6 +73,7 @@ void shouldPreferRuntimeOptionsOverDefaultOptions() {
assertThat(options.getWidth()).isEqualTo(1024);
assertThat(options.getHeight()).isEqualTo(768);
assertThat(options.getResponseFormat()).isEqualTo("application/json");
assertThat(options.getResponseFormatAsEnum()).isEqualTo(ImageResponseFormat.APPLICATION_JSON);
assertThat(options.getCfgScale()).isEqualTo(14.0f);
assertThat(options.getClipGuidancePreset()).isEqualTo("FAST_GREEN");
assertThat(options.getSampler()).isEqualTo("DDPM");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ public interface ImageOptions extends ModelOptions {

String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64

default ImageResponseFormat getResponseFormatAsEnum(); // convenience conversion helper

}
----

Expand All @@ -112,6 +114,10 @@ public class ImageResponse implements ModelResponse<ImageGeneration> {

private final List<ImageGeneration> imageGenerations;

Optional<byte[]> getResultAsBytes();

List<byte[]> getResultsAsBytes();

@Override
public ImageGeneration getResult() {
// get the first result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

package org.springframework.ai.image;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.Base64;
import java.util.Objects;
import java.util.Optional;

import org.springframework.util.StringUtils;

public class Image {

Expand Down Expand Up @@ -72,4 +78,15 @@ public int hashCode() {
return Objects.hash(this.url, this.b64Json);
}

public Optional<byte[]> getB64JsonAsBytes() {
if (!StringUtils.hasText(this.b64Json)) {
return Optional.empty();
}
return Optional.of(Base64.getDecoder().decode(this.b64Json));
}

public Optional<InputStream> getB64JsonAsInputStream() {
return getB64JsonAsBytes().map(ByteArrayInputStream::new);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.springframework.ai.model.ModelOptions;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

/**
* ImageOptions represent the common options, portable across different image generation
Expand All @@ -40,6 +41,14 @@ public interface ImageOptions extends ModelOptions {
@Nullable
String getResponseFormat();

default @Nullable ImageResponseFormat getResponseFormatAsEnum() {
String responseFormat = getResponseFormat();
if (!StringUtils.hasText(responseFormat)) {
return null;
}
return ImageResponseFormat.fromValue(responseFormat);
}

@Nullable
String getStyle();

Expand Down
Loading