Skip to content

Commit

Permalink
Add a Tensor.device function.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 9, 2020
1 parent 4f4748f commit 434c2f2
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/stubs/torch_bindings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ module C (F : Cstubs.FOREIGN) = struct
(* kind *)
@-> returning t)

let device = foreign "at_device" (t @-> returning int)
let defined = foreign "at_defined" (t @-> returning bool)
let num_dims = foreign "at_dim" (t @-> returning int)
let shape = foreign "at_shape" (t @-> ptr int (* dims *) @-> returning void)
Expand Down
5 changes: 3 additions & 2 deletions src/tests/tensor_tests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ let%expect_test _ =
t.%.{[ 1; 1 ]} <- 42.0;
t.%.{[ 3; 0 ]} <- 1.337;
for i = 0 to 3 do
Stdio.printf "%f %f\n" (t.%.{[ i; 0 ]}) (t.%.{[ i; 1 ]})
Stdio.printf "%f %f\n" t.%.{[ i; 0 ]} t.%.{[ i; 1 ]}
done;
[%expect
{|
Expand Down Expand Up @@ -144,4 +144,5 @@ let%expect_test _ =
|> List.iter ~f:(fun t -> Tensor.to_int1_exn t |> Stdio.printf !"%{sexp:int array}\n");
[%expect {|
(3 1 4)
(1 5 9) |}]
(1 5 9) |}];
assert (Tensor.device t = Cpu)
2 changes: 2 additions & 0 deletions src/wrapper/device.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ let to_int = function
| Cuda i ->
if i < 0 then Printf.sprintf "negative index for cuda device" |> failwith;
i

let of_int i = if i < 0 then Cpu else Cuda i
1 change: 1 addition & 0 deletions src/wrapper/device.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ type t =
| Cuda of int

val to_int : t -> int
val of_int : int -> t
8 changes: 8 additions & 0 deletions src/wrapper/torch_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ at::Device device_of_int(int d) {
return at::Device(at::kCUDA, /*index=*/d);
}

int at_device(tensor tensor) {
PROTECT (
auto device = tensor->device();
if (device.is_cpu()) return -1;
return device.index();
)
}

tensor at_new_tensor() {
PROTECT(
return new torch::Tensor();
Expand Down
1 change: 1 addition & 0 deletions src/wrapper/torch_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void at_copy_data(tensor tensor, void *vs, int64_t numel, int element_size_in_by
tensor at_float_vec(double *values, int value_len, int type);
tensor at_int_vec(int64_t *values, int value_len, int type);

int at_device(tensor);
int at_defined(tensor);
int at_dim(tensor);
void at_shape(tensor, int *);
Expand Down
1 change: 1 addition & 0 deletions src/wrapper/wrapper.ml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ module Tensor = struct
let min = min1
let copy_ t ~src = copy_ t src
let defined = defined
let device t = device t |> Device.of_int

let new_tensor () =
let t = new_tensor () in
Expand Down
1 change: 1 addition & 0 deletions src/wrapper/wrapper.mli
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ module Tensor : sig
val mean : t -> t
val argmax : ?dim:int -> ?keepdim:bool -> t -> t
val defined : t -> bool
val device : t -> Device.t
val copy_ : t -> src:t -> unit
val max : t -> t -> t
val min : t -> t -> t
Expand Down

0 comments on commit 434c2f2

Please sign in to comment.