Added more presets to orx-runway

This commit is contained in:
Edwin Jakobs
2020-01-05 22:38:55 +01:00
parent 18bda2e364
commit 49b34ed902
2 changed files with 67 additions and 24 deletions

View File

@@ -3,29 +3,61 @@ package org.openrndr.extra.runway
import com.google.gson.annotations.SerializedName import com.google.gson.annotations.SerializedName
// -- AttnGAN // -- AttnGAN
class CaptionRequest(val caption: String) class AttnGANRequest(val caption: String)
class CaptionResult(val result: String)
class AttnGANResult(val result: String)
// -- BDCN // -- BDCN
class BdcnRequest(val input_image: String) class BdcnRequest(val input_image: String)
class BdcnResult(val output_image: String) class BdcnResult(val output_image: String)
// -- BigBiGAN // -- BigBiGAN
class BigBiGANQuery(@SerializedName("input_image") val inputImage: String) class BigBiGANQuery(@SerializedName("input_image") val inputImage: String)
class BigBiGANResult(@SerializedName("output_image") val outputImage: String) class BigBiGANResult(@SerializedName("output_image") val outputImage: String)
// -- SPADE-COCO // -- SPADE-COCO
class SpadeCocoRequest(val semantic_map: String) class SpadeCocoRequest(val semantic_map: String)
class SpadeCocoResult(val output: String) class SpadeCocoResult(val output: String)
// -- GPT-2 // -- GPT-2
class Gpt2Request(val prompt: String) class Gpt2Request(val prompt: String, val seed: Int = 0, @SerializedName("sequence_length") val sequenceLength: Int = 128)
class Gpt2Result(val text: String) class Gpt2Result(val text: String)
// -- im2txt // -- im2txt
class Im2txtRequest(val image: String) class Im2txtRequest(val image: String)
class Im2txtResult(val caption: String) class Im2txtResult(val caption: String)
// -- PSENet // -- PSENet
class PsenetRequest(@SerializedName("input_image") val inputImage: String) class PsenetRequest(@SerializedName("input_image") val inputImage: String)
class PsenetResult(val bboxes: Array<Array<Double>>) class PsenetResult(val bboxes: Array<Array<Double>>)
// -- Face landmarks
class FaceLandmarksRequest(val photo: String)
class FaceLandmarksResponse(val points: List<List<Double>>, val labels: List<String>)
// -- StyleGAN
/**
* StyleGAN request
* @param z a list of 512 doubles
*/
class StyleGANRequest(val z: List<Double>, val truncation: Double = 1.0)
class StyleGANResponse(val image: String)
// -- DeOldify
class DeOldifyRequest(val image: String, val renderFactor: Int = 20)
class DeOldifyResponse(val image: String)
// -- DenseCap
class DenseCapRequest(val image: String, @SerializedName("max_detections") val maxDetections: Int = 10)
class DenseCapResponse(val bboxes: List<List<Double>>, val classes: List<String>, val scores: List<Double>)

View File

@@ -5,8 +5,11 @@ import org.openrndr.draw.ColorBuffer
import org.openrndr.draw.FileFormat import org.openrndr.draw.FileFormat
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.File import java.io.File
import java.io.IOException
import java.net.HttpURLConnection import java.net.HttpURLConnection
import java.net.SocketTimeoutException
import java.net.URL import java.net.URL
import java.net.UnknownHostException
import java.util.* import java.util.*
/** /**
@@ -41,16 +44,18 @@ fun ColorBuffer.Companion.fromData(data: String): ColorBuffer {
* @param target url string e.g. http://localhost:8000/query * @param target url string e.g. http://localhost:8000/query
*/ */
inline fun <Q, reified R> runwayQuery(target: String, query: Q): R { inline fun <Q, reified R> runwayQuery(target: String, query: Q): R {
try {
val queryJson = Gson().toJson(query) val queryJson = Gson().toJson(query)
val connection = URL(target).openConnection() as HttpURLConnection val connection = URL(target).openConnection() as HttpURLConnection
with(connection) { //with(connection) {
doOutput = true connection.doOutput = true
connectTimeout = 1_000 connection.connectTimeout = 1_000
readTimeout = 200_000 connection.readTimeout = 200_000
requestMethod = "POST" connection.requestMethod = "POST"
setRequestProperty("Content-Type", "application/json") connection.setRequestProperty("Content-Type", "application/json")
setRequestProperty("Accept", "application/json") connection.setRequestProperty("Accept", "application/json")
} //}
val outputStream = connection.outputStream val outputStream = connection.outputStream
outputStream.write(queryJson.toByteArray()) outputStream.write(queryJson.toByteArray())
@@ -58,7 +63,13 @@ inline fun <Q, reified R> runwayQuery(target: String, query: Q): R {
val inputStream = connection.inputStream val inputStream = connection.inputStream
val responseJson = String(inputStream.readBytes()) val responseJson = String(inputStream.readBytes())
println(responseJson)
inputStream.close() inputStream.close()
connection.disconnect() connection.disconnect()
return Gson().fromJson(responseJson, R::class.java) return Gson().fromJson(responseJson, R::class.java)
} catch (e: SocketTimeoutException) {
error("RunwayML connection timed out. Check if Runway and model are running.")
} catch (e: UnknownHostException) {
error("Runway host not found. Check if Runway and model are running.")
}
} }