Added more presets to orx-runway
This commit is contained in:
@@ -3,29 +3,61 @@ package org.openrndr.extra.runway
|
|||||||
import com.google.gson.annotations.SerializedName
|
import com.google.gson.annotations.SerializedName
|
||||||
|
|
||||||
// -- AttnGAN
|
// -- AttnGAN
|
||||||
class CaptionRequest(val caption: String)
|
class AttnGANRequest(val caption: String)
|
||||||
class CaptionResult(val result: String)
|
|
||||||
|
class AttnGANResult(val result: String)
|
||||||
|
|
||||||
// -- BDCN
|
// -- BDCN
|
||||||
class BdcnRequest(val input_image: String)
|
class BdcnRequest(val input_image: String)
|
||||||
|
|
||||||
class BdcnResult(val output_image: String)
|
class BdcnResult(val output_image: String)
|
||||||
|
|
||||||
// -- BigBiGAN
|
// -- BigBiGAN
|
||||||
class BigBiGANQuery(@SerializedName("input_image") val inputImage: String)
|
class BigBiGANQuery(@SerializedName("input_image") val inputImage: String)
|
||||||
|
|
||||||
class BigBiGANResult(@SerializedName("output_image") val outputImage: String)
|
class BigBiGANResult(@SerializedName("output_image") val outputImage: String)
|
||||||
|
|
||||||
// -- SPADE-COCO
|
// -- SPADE-COCO
|
||||||
class SpadeCocoRequest(val semantic_map: String)
|
class SpadeCocoRequest(val semantic_map: String)
|
||||||
|
|
||||||
class SpadeCocoResult(val output: String)
|
class SpadeCocoResult(val output: String)
|
||||||
|
|
||||||
// -- GPT-2
|
// -- GPT-2
|
||||||
class Gpt2Request(val prompt: String)
|
class Gpt2Request(val prompt: String, val seed: Int = 0, @SerializedName("sequence_length") val sequenceLength: Int = 128)
|
||||||
|
|
||||||
class Gpt2Result(val text: String)
|
class Gpt2Result(val text: String)
|
||||||
|
|
||||||
// -- im2txt
|
// -- im2txt
|
||||||
class Im2txtRequest(val image: String)
|
class Im2txtRequest(val image: String)
|
||||||
|
|
||||||
class Im2txtResult(val caption: String)
|
class Im2txtResult(val caption: String)
|
||||||
|
|
||||||
// -- PSENet
|
// -- PSENet
|
||||||
class PsenetRequest(@SerializedName("input_image") val inputImage: String)
|
class PsenetRequest(@SerializedName("input_image") val inputImage: String)
|
||||||
class PsenetResult(val bboxes: Array<Array<Double>>)
|
|
||||||
|
class PsenetResult(val bboxes: Array<Array<Double>>)
|
||||||
|
|
||||||
|
// -- Face landmarks
|
||||||
|
class FaceLandmarksRequest(val photo: String)
|
||||||
|
|
||||||
|
class FaceLandmarksResponse(val points: List<List<Double>>, val labels: List<String>)
|
||||||
|
|
||||||
|
// -- StyleGAN
|
||||||
|
|
||||||
|
/**
|
||||||
|
* StyleGAN request
|
||||||
|
* @param z a list of 512 doubles
|
||||||
|
*/
|
||||||
|
class StyleGANRequest(val z: List<Double>, val truncation: Double = 1.0)
|
||||||
|
|
||||||
|
class StyleGANResponse(val image: String)
|
||||||
|
|
||||||
|
// -- DeOldify
|
||||||
|
class DeOldifyRequest(val image: String, val renderFactor: Int = 20)
|
||||||
|
|
||||||
|
class DeOldifyResponse(val image: String)
|
||||||
|
|
||||||
|
// -- DenseCap
|
||||||
|
|
||||||
|
class DenseCapRequest(val image: String, @SerializedName("max_detections") val maxDetections: Int = 10)
|
||||||
|
class DenseCapResponse(val bboxes: List<List<Double>>, val classes: List<String>, val scores: List<Double>)
|
||||||
@@ -5,8 +5,11 @@ import org.openrndr.draw.ColorBuffer
|
|||||||
import org.openrndr.draw.FileFormat
|
import org.openrndr.draw.FileFormat
|
||||||
import java.io.ByteArrayInputStream
|
import java.io.ByteArrayInputStream
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
import java.io.IOException
|
||||||
import java.net.HttpURLConnection
|
import java.net.HttpURLConnection
|
||||||
|
import java.net.SocketTimeoutException
|
||||||
import java.net.URL
|
import java.net.URL
|
||||||
|
import java.net.UnknownHostException
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -29,7 +32,7 @@ fun ColorBuffer.toData(format: FileFormat = FileFormat.JPG): String {
|
|||||||
fun ColorBuffer.Companion.fromData(data: String): ColorBuffer {
|
fun ColorBuffer.Companion.fromData(data: String): ColorBuffer {
|
||||||
val decoder = Base64.getDecoder()
|
val decoder = Base64.getDecoder()
|
||||||
val commaIndex = data.indexOf(",")
|
val commaIndex = data.indexOf(",")
|
||||||
val imageData = decoder.decode(data.drop(commaIndex+1))
|
val imageData = decoder.decode(data.drop(commaIndex + 1))
|
||||||
|
|
||||||
ByteArrayInputStream(imageData).use {
|
ByteArrayInputStream(imageData).use {
|
||||||
return ColorBuffer.fromStream(it)
|
return ColorBuffer.fromStream(it)
|
||||||
@@ -41,24 +44,32 @@ fun ColorBuffer.Companion.fromData(data: String): ColorBuffer {
|
|||||||
* @param target url string e.g. http://localhost:8000/query
|
* @param target url string e.g. http://localhost:8000/query
|
||||||
*/
|
*/
|
||||||
inline fun <Q, reified R> runwayQuery(target: String, query: Q): R {
|
inline fun <Q, reified R> runwayQuery(target: String, query: Q): R {
|
||||||
val queryJson = Gson().toJson(query)
|
|
||||||
val connection = URL(target).openConnection() as HttpURLConnection
|
try {
|
||||||
with(connection) {
|
val queryJson = Gson().toJson(query)
|
||||||
doOutput = true
|
val connection = URL(target).openConnection() as HttpURLConnection
|
||||||
connectTimeout = 1_000
|
//with(connection) {
|
||||||
readTimeout = 200_000
|
connection.doOutput = true
|
||||||
requestMethod = "POST"
|
connection.connectTimeout = 1_000
|
||||||
setRequestProperty("Content-Type", "application/json")
|
connection.readTimeout = 200_000
|
||||||
setRequestProperty("Accept", "application/json")
|
connection.requestMethod = "POST"
|
||||||
|
connection.setRequestProperty("Content-Type", "application/json")
|
||||||
|
connection.setRequestProperty("Accept", "application/json")
|
||||||
|
//}
|
||||||
|
|
||||||
|
val outputStream = connection.outputStream
|
||||||
|
outputStream.write(queryJson.toByteArray())
|
||||||
|
outputStream.flush()
|
||||||
|
|
||||||
|
val inputStream = connection.inputStream
|
||||||
|
val responseJson = String(inputStream.readBytes())
|
||||||
|
println(responseJson)
|
||||||
|
inputStream.close()
|
||||||
|
connection.disconnect()
|
||||||
|
return Gson().fromJson(responseJson, R::class.java)
|
||||||
|
} catch (e: SocketTimeoutException) {
|
||||||
|
error("RunwayML connection timed out. Check if Runway and model are running.")
|
||||||
|
} catch (e: UnknownHostException) {
|
||||||
|
error("Runway host not found. Check if Runway and model are running.")
|
||||||
}
|
}
|
||||||
|
|
||||||
val outputStream = connection.outputStream
|
|
||||||
outputStream.write(queryJson.toByteArray())
|
|
||||||
outputStream.flush()
|
|
||||||
|
|
||||||
val inputStream = connection.inputStream
|
|
||||||
val responseJson = String(inputStream.readBytes())
|
|
||||||
inputStream.close()
|
|
||||||
connection.disconnect()
|
|
||||||
return Gson().fromJson(responseJson, R::class.java)
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user