diff --git a/src/stubs/torch_bindings.ml b/src/stubs/torch_bindings.ml index 3fe906b..f41b3d2 100644 --- a/src/stubs/torch_bindings.ml +++ b/src/stubs/torch_bindings.ml @@ -61,6 +61,7 @@ module C (F : Cstubs.FOREIGN) = struct (* kind *) @-> returning t) + let device = foreign "at_device" (t @-> returning int) let defined = foreign "at_defined" (t @-> returning bool) let num_dims = foreign "at_dim" (t @-> returning int) let shape = foreign "at_shape" (t @-> ptr int (* dims *) @-> returning void) diff --git a/src/tests/tensor_tests.ml b/src/tests/tensor_tests.ml index 9765c7f..b210af9 100644 --- a/src/tests/tensor_tests.ml +++ b/src/tests/tensor_tests.ml @@ -24,7 +24,7 @@ let%expect_test _ = t.%.{[ 1; 1 ]} <- 42.0; t.%.{[ 3; 0 ]} <- 1.337; for i = 0 to 3 do - Stdio.printf "%f %f\n" (t.%.{[ i; 0 ]}) (t.%.{[ i; 1 ]}) + Stdio.printf "%f %f\n" t.%.{[ i; 0 ]} t.%.{[ i; 1 ]} done; [%expect {| @@ -144,4 +144,5 @@ let%expect_test _ = |> List.iter ~f:(fun t -> Tensor.to_int1_exn t |> Stdio.printf !"%{sexp:int array}\n"); [%expect {| (3 1 4) - (1 5 9) |}] + (1 5 9) |}]; + assert (Tensor.device t = Cpu) diff --git a/src/wrapper/device.ml b/src/wrapper/device.ml index 4ad7fd5..5d82b1c 100644 --- a/src/wrapper/device.ml +++ b/src/wrapper/device.ml @@ -8,3 +8,5 @@ let to_int = function | Cuda i -> if i < 0 then Printf.sprintf "negative index for cuda device" |> failwith; i + +let of_int i = if i < 0 then Cpu else Cuda i diff --git a/src/wrapper/device.mli b/src/wrapper/device.mli index 7a0eb23..a1148ae 100644 --- a/src/wrapper/device.mli +++ b/src/wrapper/device.mli @@ -5,3 +5,4 @@ type t = | Cuda of int val to_int : t -> int +val of_int : int -> t diff --git a/src/wrapper/torch_api.cpp b/src/wrapper/torch_api.cpp index edca9ca..359cc3d 100644 --- a/src/wrapper/torch_api.cpp +++ b/src/wrapper/torch_api.cpp @@ -22,6 +22,14 @@ at::Device device_of_int(int d) { return at::Device(at::kCUDA, /*index=*/d); } +int at_device(tensor tensor) { + PROTECT ( + auto device = tensor->device(); + if (device.is_cpu()) return -1; + return device.index(); + ) +} + tensor at_new_tensor() { PROTECT( return new torch::Tensor(); diff --git a/src/wrapper/torch_api.h b/src/wrapper/torch_api.h index 1f6bd15..006df5a 100644 --- a/src/wrapper/torch_api.h +++ b/src/wrapper/torch_api.h @@ -30,6 +30,7 @@ void at_copy_data(tensor tensor, void *vs, int64_t numel, int element_size_in_by tensor at_float_vec(double *values, int value_len, int type); tensor at_int_vec(int64_t *values, int value_len, int type); +int at_device(tensor); int at_defined(tensor); int at_dim(tensor); void at_shape(tensor, int *); diff --git a/src/wrapper/wrapper.ml b/src/wrapper/wrapper.ml index 9e9c892..723b472 100644 --- a/src/wrapper/wrapper.ml +++ b/src/wrapper/wrapper.ml @@ -163,6 +163,7 @@ module Tensor = struct let min = min1 let copy_ t ~src = copy_ t src let defined = defined + let device t = device t |> Device.of_int let new_tensor () = let t = new_tensor () in diff --git a/src/wrapper/wrapper.mli b/src/wrapper/wrapper.mli index c05fcea..2e433c7 100644 --- a/src/wrapper/wrapper.mli +++ b/src/wrapper/wrapper.mli @@ -51,6 +51,7 @@ module Tensor : sig val mean : t -> t val argmax : ?dim:int -> ?keepdim:bool -> t -> t val defined : t -> bool + val device : t -> Device.t val copy_ : t -> src:t -> unit val max : t -> t -> t val min : t -> t -> t