From 49b34ed9025590c3cb192161aaf79c2037a4f253 Mon Sep 17 00:00:00 2001 From: Edwin Jakobs Date: Sun, 5 Jan 2020 22:38:55 +0100 Subject: [PATCH] Added more presets to orx-runway --- orx-runway/src/main/kotlin/Presets.kt | 40 +++++++++++++++++-- orx-runway/src/main/kotlin/RunwayHttp.kt | 51 ++++++++++++++---------- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/orx-runway/src/main/kotlin/Presets.kt b/orx-runway/src/main/kotlin/Presets.kt index 5b705e57..43285fe0 100644 --- a/orx-runway/src/main/kotlin/Presets.kt +++ b/orx-runway/src/main/kotlin/Presets.kt @@ -3,29 +3,61 @@ package org.openrndr.extra.runway import com.google.gson.annotations.SerializedName // -- AttnGAN -class CaptionRequest(val caption: String) -class CaptionResult(val result: String) +class AttnGANRequest(val caption: String) + +class AttnGANResult(val result: String) // -- BDCN class BdcnRequest(val input_image: String) + class BdcnResult(val output_image: String) // -- BigBiGAN class BigBiGANQuery(@SerializedName("input_image") val inputImage: String) + class BigBiGANResult(@SerializedName("output_image") val outputImage: String) // -- SPADE-COCO class SpadeCocoRequest(val semantic_map: String) + class SpadeCocoResult(val output: String) // -- GPT-2 -class Gpt2Request(val prompt: String) +class Gpt2Request(val prompt: String, val seed: Int = 0, @SerializedName("sequence_length") val sequenceLength: Int = 128) + class Gpt2Result(val text: String) // -- im2txt class Im2txtRequest(val image: String) + class Im2txtResult(val caption: String) // -- PSENet class PsenetRequest(@SerializedName("input_image") val inputImage: String) -class PsenetResult(val bboxes: Array>) \ No newline at end of file + +class PsenetResult(val bboxes: Array>) + +// -- Face landmarks +class FaceLandmarksRequest(val photo: String) + +class FaceLandmarksResponse(val points: List>, val labels: List) + +// -- StyleGAN + +/** + * StyleGAN request + * @param z a list of 512 doubles + */ +class StyleGANRequest(val z: List, val truncation: Double = 1.0) + +class StyleGANResponse(val image: String) + +// -- DeOldify +class DeOldifyRequest(val image: String, val renderFactor: Int = 20) + +class DeOldifyResponse(val image: String) + +// -- DenseCap + +class DenseCapRequest(val image: String, @SerializedName("max_detections") val maxDetections: Int = 10) +class DenseCapResponse(val bboxes: List>, val classes: List, val scores: List) \ No newline at end of file diff --git a/orx-runway/src/main/kotlin/RunwayHttp.kt b/orx-runway/src/main/kotlin/RunwayHttp.kt index a5928e2a..7bed08e5 100644 --- a/orx-runway/src/main/kotlin/RunwayHttp.kt +++ b/orx-runway/src/main/kotlin/RunwayHttp.kt @@ -5,8 +5,11 @@ import org.openrndr.draw.ColorBuffer import org.openrndr.draw.FileFormat import java.io.ByteArrayInputStream import java.io.File +import java.io.IOException import java.net.HttpURLConnection +import java.net.SocketTimeoutException import java.net.URL +import java.net.UnknownHostException import java.util.* /** @@ -29,7 +32,7 @@ fun ColorBuffer.toData(format: FileFormat = FileFormat.JPG): String { fun ColorBuffer.Companion.fromData(data: String): ColorBuffer { val decoder = Base64.getDecoder() val commaIndex = data.indexOf(",") - val imageData = decoder.decode(data.drop(commaIndex+1)) + val imageData = decoder.decode(data.drop(commaIndex + 1)) ByteArrayInputStream(imageData).use { return ColorBuffer.fromStream(it) @@ -41,24 +44,32 @@ fun ColorBuffer.Companion.fromData(data: String): ColorBuffer { * @param target url string e.g. http://localhost:8000/query */ inline fun runwayQuery(target: String, query: Q): R { - val queryJson = Gson().toJson(query) - val connection = URL(target).openConnection() as HttpURLConnection - with(connection) { - doOutput = true - connectTimeout = 1_000 - readTimeout = 200_000 - requestMethod = "POST" - setRequestProperty("Content-Type", "application/json") - setRequestProperty("Accept", "application/json") + + try { + val queryJson = Gson().toJson(query) + val connection = URL(target).openConnection() as HttpURLConnection + //with(connection) { + connection.doOutput = true + connection.connectTimeout = 1_000 + connection.readTimeout = 200_000 + connection.requestMethod = "POST" + connection.setRequestProperty("Content-Type", "application/json") + connection.setRequestProperty("Accept", "application/json") + //} + + val outputStream = connection.outputStream + outputStream.write(queryJson.toByteArray()) + outputStream.flush() + + val inputStream = connection.inputStream + val responseJson = String(inputStream.readBytes()) + println(responseJson) + inputStream.close() + connection.disconnect() + return Gson().fromJson(responseJson, R::class.java) + } catch (e: SocketTimeoutException) { + error("RunwayML connection timed out. Check if Runway and model are running.") + } catch (e: UnknownHostException) { + error("Runway host not found. Check if Runway and model are running.") } - - val outputStream = connection.outputStream - outputStream.write(queryJson.toByteArray()) - outputStream.flush() - - val inputStream = connection.inputStream - val responseJson = String(inputStream.readBytes()) - inputStream.close() - connection.disconnect() - return Gson().fromJson(responseJson, R::class.java) } \ No newline at end of file