diff --git a/orx-tensorflow/src/main/kotlin/Tensor.kt b/orx-tensorflow/src/main/kotlin/Tensor.kt index bc55525b..6cfa163a 100644 --- a/orx-tensorflow/src/main/kotlin/Tensor.kt +++ b/orx-tensorflow/src/main/kotlin/Tensor.kt @@ -36,12 +36,12 @@ fun ColorBuffer.copyTo(tensor: Tensor) { fun Tensor.copyTo(colorBuffer: ColorBuffer) { val s = shape() - require(s.numDimensions() == 2 || s.numDimensions() == 3) val components = when { + s.numDimensions() == 2 -> 1 s.numDimensions() == 3 -> s.size(2).toInt() s.numDimensions() == 4 -> s.size(3).toInt() - else -> 1 + else -> error("can't copy to colorbuffer from ${s.numDimensions()}D tensor") } val format = when (components) {