Skip to content

Commit

Permalink
get face restorers (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
Robothy authored Oct 15, 2023
1 parent 2d6c39e commit 174ae59
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 9 deletions.
11 changes: 11 additions & 0 deletions src/main/java/io/github/robothy/sdwebui/sdk/GetFaceRestorers.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.github.robothy.sdwebui.sdk;

import io.github.robothy.sdwebui.sdk.models.results.FaceRestorer;

import java.util.List;

public interface GetFaceRestorers {

List<FaceRestorer> getFaceRestorers();

}
2 changes: 1 addition & 1 deletion src/main/java/io/github/robothy/sdwebui/sdk/SdWebui.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import java.lang.reflect.Proxy;

public interface SdWebui extends SystemInfoFetcher, Txt2Image, Image2Image, GetSdModels {
public interface SdWebui extends SystemInfoFetcher, Txt2Image, Image2Image, GetSdModels, GetFaceRestorers {

static SdWebui create(String endpoint) {
SdWebuiOptions options = new SdWebuiOptions(endpoint);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.github.robothy.sdwebui.sdk.models.results;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;

@Getter
public class FaceRestorer {

@JsonProperty("name")
private String name;

@JsonProperty("cmd_dir")
private String cmdDir;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.github.robothy.sdwebui.sdk.services;

import io.github.robothy.sdwebui.sdk.GetFaceRestorers;
import io.github.robothy.sdwebui.sdk.SdWebuiBeanContainer;
import io.github.robothy.sdwebui.sdk.models.results.FaceRestorer;

import java.util.Arrays;
import java.util.List;

public class DefaultGetFaceRestorersService implements GetFaceRestorers {

private final SdWebuiBeanContainer container;

public DefaultGetFaceRestorersService(SdWebuiBeanContainer container) {
this.container = container;
}

@Override
public List<FaceRestorer> getFaceRestorers() {
return Arrays.asList(container.getBean(CommonGetService.class).getData("/sdapi/v1/face-restorers", FaceRestorer[].class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ public void register(Class<?> serviceClass, Object service) {

private void init() {
CloseableHttpClient closeableHttpClient = HttpClients.createDefault();
this.services.put(SdWebuiOptions.class, sdWebuiOptions);
this.services.put(ObjectMapper.class, new ObjectMapper());
this.services.put(HttpClient.class, closeableHttpClient);
this.services.put(SystemInfo.class, new CacheableSystemInfoFetcher(sdWebuiOptions.getEndpoint(), this));
this.services.put(Txt2Image.class, new DefaultTxt2ImageService(this));
this.services.put(Image2Image.class, new DefaultImage2ImageService(this));
this.services.put(CommonGetService.class, new CommonGetService(this));
this.services.put(GetSdModels.class, new DefaultGetSdModelService(this));
register(SdWebuiOptions.class, sdWebuiOptions);
register(ObjectMapper.class, new ObjectMapper());
register(HttpClient.class, closeableHttpClient);
register(SystemInfo.class, new CacheableSystemInfoFetcher(sdWebuiOptions.getEndpoint(), this));
register(Txt2Image.class, new DefaultTxt2ImageService(this));
register(Image2Image.class, new DefaultImage2ImageService(this));
register(CommonGetService.class, new CommonGetService(this));
register(GetSdModels.class, new DefaultGetSdModelService(this));
register(GetFaceRestorers.class, new DefaultGetFaceRestorersService(this));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.github.robothy.sdwebui.sdk;

import io.github.robothy.sdwebui.sdk.models.results.FaceRestorer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockserver.client.MockServerClient;
import org.mockserver.junit.jupiter.MockServerExtension;
import org.mockserver.model.HttpRequest;
import org.mockserver.model.HttpResponse;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(MockServerExtension.class)
class GetFaceRestorersTest {

@Test
void getFaceRestorers(MockServerClient client) {
client.when(new HttpRequest().withMethod("GET").withPath("/sdapi/v1/face-restorers"))
.respond(new HttpResponse().withStatusCode(200).withBody("[\n" +
" {\n" +
" \"name\": \"CodeFormer\",\n" +
" \"cmd_dir\": \"C:\\\\Users\\\\admin\\\\PythonProjects\\\\stable-diffusion-webui\\\\models\\\\Codeformer\"\n" +
" },\n" +
" {\n" +
" \"name\": \"GFPGAN\",\n" +
" \"cmd_dir\": null\n" +
" }\n" +
"]"));

List<FaceRestorer> faceRestorers = SdWebui.create("http://localhost:" + client.getPort())
.getFaceRestorers();
assertEquals(2, faceRestorers.size());
assertEquals("CodeFormer", faceRestorers.get(0).getName());
assertEquals("C:\\Users\\admin\\PythonProjects\\stable-diffusion-webui\\models\\Codeformer", faceRestorers.get(0).getCmdDir());
assertEquals("GFPGAN", faceRestorers.get(1).getName());
assertNull(faceRestorers.get(1).getCmdDir());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.github.robothy.sdwebui.sdk.models.results;

import io.github.robothy.sdwebui.sdk.utils.JsonUtils;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

class FaceRestorerTest {

@Test
void testSerialization() {
String faceRestorersJson = "[\n" +
" {\n" +
" \"name\": \"CodeFormer\",\n" +
" \"cmd_dir\": \"C:\\\\Users\\\\admin\\\\PythonProjects\\\\stable-diffusion-webui\\\\models\\\\Codeformer\"\n" +
" },\n" +
" {\n" +
" \"name\": \"GFPGAN\",\n" +
" \"cmd_dir\": null\n" +
" }\n" +
"]";

FaceRestorer[] faceRestorers = JsonUtils.fromJson(faceRestorersJson, FaceRestorer[].class);
assertEquals(2, faceRestorers.length);
assertEquals("CodeFormer", faceRestorers[0].getName());
assertEquals("C:\\Users\\admin\\PythonProjects\\stable-diffusion-webui\\models\\Codeformer", faceRestorers[0].getCmdDir());
assertEquals("GFPGAN", faceRestorers[1].getName());
assertNull(faceRestorers[1].getCmdDir());
}

}

0 comments on commit 174ae59

Please sign in to comment.