Update for libtorch 1.9 (#57)
* Start adapting the code generation for libtorch 1.9.

* Get the tests to run.

* Get the examples to compile.

* Update to the ocaml github workflow v2.
LaurentMazare authored Aug 15, 2021
1 parent 6c5502d commit 73f6dab
Showing 23 changed files with 7,504 additions and 3,764 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
@@ -23,7 +23,7 @@ jobs:
uses: actions/checkout@v2

- name: Use OCaml ${{ matrix.ocaml-version }}
- uses: avsm/setup-ocaml@v1
+ uses: avsm/setup-ocaml@v2
with:
ocaml-version: ${{ matrix.ocaml-version }}

10 changes: 5 additions & 5 deletions README.md
@@ -7,11 +7,11 @@ differentiation.

These bindings use the [PyTorch C++ API](https://pytorch.org/cppdocs/) and are
mostly automatically generated. The current GitHub tip and the opam package v0.7
- corresponds to PyTorch **v1.8.0**.
+ corresponds to PyTorch **v1.9.0**.

On Linux note that you will need the PyTorch version using the cxx11 abi
- [cpu version](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.0%2Bcpu.zip),
- [cuda 10.2 version](https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.8.0.zip).
+ [cpu version](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.9.0%2Bcpu.zip),
+ [cuda 10.2 version](https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.9.0.zip).

## Opam Installation

@@ -152,8 +152,8 @@ This alternative way to install __ocaml-torch__ could be useful to run with GPU
acceleration enabled.

The libtorch library can be downloaded from the [PyTorch
- website](https://pytorch.org/resources) ([1.8.0 cpu
- version](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.0+cpu.zip)).
+ website](https://pytorch.org/resources) ([1.9.0 cpu
+ version](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.9.0+cpu.zip)).

Download and extract the libtorch library then to build all the examples run:

2 changes: 1 addition & 1 deletion examples/gan/gan_stability.ml
@@ -137,7 +137,7 @@ let grad2 d_out x_in =
in
Tensor.(grad_dout * grad_dout)
|> Tensor.view ~size:[ batch_size; -1 ]
- |> Tensor.sum1 ~dim:[ 1 ] ~keepdim:false ~dtype:(T Float)
+ |> Tensor.sum_dim_intlist ~dim:[ 1 ] ~keepdim:false ~dtype:(T Float)

let () =
let module Sys = Caml.Sys in
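For readers tracking the rename: a minimal sketch of the new reduction call, assuming the usual ocaml-torch Tensor constructors (of_float2 is used here purely for illustration). sum_dim_intlist takes the same ~dim/~keepdim/~dtype arguments that sum1 did.

open Torch

(* Sum a [2; 2] tensor over dimension 1, dropping that axis: the result
   holds the per-row sums [3.; 7.]. *)
let () =
  let t = Tensor.of_float2 [| [| 1.; 2. |]; [| 3.; 4. |] |] in
  let row_sums = Tensor.sum_dim_intlist t ~dim:[ 1 ] ~keepdim:false ~dtype:(T Float) in
  Tensor.print row_sums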
2 changes: 1 addition & 1 deletion examples/gan/progressive_growing_gan.ml
@@ -13,7 +13,7 @@ let leaky_relu xs = Tensor.(max xs (xs * f 0.2))

let pixel_norm xs =
Tensor.(
- xs / (sqrt (mean1 (xs * xs) ~dim:[ 1 ] ~keepdim:true ~dtype:(T Float)) + f 1e-8))
+ xs / (sqrt (mean_dim (xs * xs) ~dim:[ 1 ] ~keepdim:true ~dtype:(T Float)) + f 1e-8))

let w_scale_layer vs ~size:sz =
let vs = Var_store.sub vs "wscale" in
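A hedged usage sketch of the pixel norm above (the NCHW shape is assumed, not taken from the diff): mean_dim ~dim:[ 1 ] ~keepdim:true averages over the channel axis, so each spatial position is rescaled by the inverse RMS of its channel vector while the tensor's shape is preserved.

open Torch

let pixel_norm xs =
  Tensor.(
    xs / (sqrt (mean_dim (xs * xs) ~dim:[ 1 ] ~keepdim:true ~dtype:(T Float)) + f 1e-8))

(* keepdim:true keeps the channel axis as size 1, so the division
   broadcasts and the output shape matches the input. *)
let () =
  let xs = Tensor.randn [ 4; 512; 8; 8 ] in
  assert (Tensor.shape (pixel_norm xs) = [ 4; 512; 8; 8 ])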
4 changes: 3 additions & 1 deletion examples/reinforcement-learning/a2c.ml
@@ -63,7 +63,9 @@ let train ~device =
in
let dist_entropy =
Tensor.(
- ~-(log_probs * probs) |> sum1 ~dim:[ -1 ] ~keepdim:false ~dtype:(T Float) |> mean)
+ ~-(log_probs * probs)
+ |> sum_dim_intlist ~dim:[ -1 ] ~keepdim:false ~dtype:(T Float)
+ |> mean)
in
let advantages =
let returns =
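A worked sketch of the entropy term above, with toy probabilities assumed for illustration: per row, the entropy is -Σ p log p, reduced over the last dimension and then averaged across the batch.

open Torch

let () =
  let probs = Tensor.of_float2 [| [| 0.5; 0.5 |]; [| 0.9; 0.1 |] |] in
  let log_probs = Tensor.log probs in
  let dist_entropy =
    Tensor.(
      ~-(log_probs * probs)
      |> sum_dim_intlist ~dim:[ -1 ] ~keepdim:false ~dtype:(T Float)
      |> mean)
  in
  (* Row entropies are ln 2 ~ 0.693 and ~0.325; the batch mean is ~0.509. *)
  Tensor.print dist_entropy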
6 changes: 4 additions & 2 deletions examples/reinforcement-learning/dqn.ml
@@ -132,11 +132,13 @@ end = struct
~dim:1
~index:(Tensor.unsqueeze actions ~dim:1)
~sparse_grad:false
- |> Tensor.squeeze1 ~dim:1
+ |> Tensor.squeeze_dim ~dim:1
in
let next_qvalues =
Tensor.no_grad (fun () ->
- Layer.forward t.model next_states |> Tensor.max2 ~dim:1 ~keepdim:false |> fst)
+ Layer.forward t.model next_states
+ |> Tensor.max_dim ~dim:1 ~keepdim:false
+ |> fst)
in
let expected_qvalues = Tensor.(rewards + (f t.gamma * next_qvalues * continue)) in
let loss = Tensor.mse_loss qvalues expected_qvalues in
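A hedged sketch of the target computation above, with toy numbers: the renamed max_dim returns a (values, indices) pair along the given dimension (hence the fst), and the Bellman target is r + γ · max over a' of Q(s', a'), gated by the continue mask on terminal steps.

open Torch

let () =
  let next_q_all = Tensor.of_float2 [| [| 1.0; 3.0 |] |] in
  let next_qvalues, _indices = Tensor.max_dim next_q_all ~dim:1 ~keepdim:false in
  let rewards = Tensor.of_float1 [| 0.5 |] in
  let continue = Tensor.of_float1 [| 1. |] in
  let gamma = 0.99 in
  (* 0.5 + 0.99 * 3.0 = 3.47 for this single non-terminal transition. *)
  let expected_qvalues = Tensor.(rewards + (f gamma * next_qvalues * continue)) in
  Tensor.print expected_qvalues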
6 changes: 3 additions & 3 deletions examples/reinforcement-learning/dqn_atari.ml
@@ -165,7 +165,7 @@ end = struct
~dim:1
~index:(Tensor.unsqueeze actions ~dim:1)
~sparse_grad:false
- |> Tensor.squeeze1 ~dim:1
+ |> Tensor.squeeze_dim ~dim:1
in
let next_qvalues =
Tensor.no_grad (fun () ->
@@ -176,10 +176,10 @@ end = struct
in
Layer.forward t.target_model next_states
|> Tensor.gather ~dim:1 ~index:actions ~sparse_grad:false
- |> Tensor.squeeze1 ~dim:1)
+ |> Tensor.squeeze_dim ~dim:1)
else
Layer.forward t.target_model next_states
- |> Tensor.max2 ~dim:1 ~keepdim:false
+ |> Tensor.max_dim ~dim:1 ~keepdim:false
|> fst)
in
let expected_qvalues = Tensor.(rewards + (f t.gamma * next_qvalues * continue)) in
4 changes: 2 additions & 2 deletions examples/reinforcement-learning/dqn_pong.ml
@@ -143,12 +143,12 @@ end = struct
~dim:1
~index:(Tensor.unsqueeze actions ~dim:1)
~sparse_grad:false
- |> Tensor.squeeze1 ~dim:1
+ |> Tensor.squeeze_dim ~dim:1
in
let next_qvalues =
Tensor.no_grad (fun () ->
Layer.forward t.target_model next_states
- |> Tensor.max2 ~dim:1 ~keepdim:false
+ |> Tensor.max_dim ~dim:1 ~keepdim:false
|> fst)
in
let expected_qvalues = Tensor.(rewards + (f t.gamma * next_qvalues * continue)) in
2 changes: 1 addition & 1 deletion examples/reinforcement-learning/policy_gradient.ml
@@ -97,7 +97,7 @@ let () =
let logits = model (Tensor.stack acc.acc_obs ~dim:0) in
let log_probs =
Tensor.(
- sum1
+ sum_dim_intlist
(action_mask * log_softmax logits ~dim:1 ~dtype:(T Float))
~dim:[ 1 ]
~keepdim:false
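A small sketch of the masking trick above, with toy logits assumed: multiplying a one-hot action mask by log_softmax and summing over dimension 1 selects the log-probability of the action that was actually taken.

open Torch

let () =
  let logits = Tensor.of_float2 [| [| 2.; 0.; 1. |] |] in
  let action_mask = Tensor.of_float2 [| [| 0.; 1.; 0. |] |] in
  let log_probs =
    Tensor.(
      sum_dim_intlist
        (action_mask * log_softmax logits ~dim:1 ~dtype:(T Float))
        ~dim:[ 1 ]
        ~keepdim:false
        ~dtype:(T Float))
  in
  (* One row, action 1 taken: prints log p(a = 1 | logits). *)
  Tensor.print log_probs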
2 changes: 1 addition & 1 deletion examples/reinforcement-learning/ppo.ml
@@ -93,7 +93,7 @@ let train ~device =
let dist_entropy =
Tensor.(
~-(log_probs * probs)
- |> sum1 ~dim:[ -1 ] ~keepdim:false ~dtype:(T Float)
+ |> sum_dim_intlist ~dim:[ -1 ] ~keepdim:false ~dtype:(T Float)
|> mean)
in
let advantages = Tensor.( - ) (Tensor.to_device returns ~device) critic in
2 changes: 1 addition & 1 deletion examples/reinforcement-learning/rollout.ml
@@ -104,7 +104,7 @@ let run t ~model =
Tensor.(t.sum_rewards *= masks);
let obs = Frame_stack.update t.frame_stack obs ~masks in
set s_actions s actions;
- set s_values s (critic |> Tensor.squeeze1 ~dim:(-1));
+ set s_values s (critic |> Tensor.squeeze_dim ~dim:(-1));
set t.s_states (s + 1) obs;
set s_rewards s reward;
set s_masks s masks
2 changes: 1 addition & 1 deletion examples/translation/seq2seq.ml
@@ -134,7 +134,7 @@ module Dec_attn : Decoder = struct
~dim:1
in
let attn_applied =
- Tensor.bmm attn_weights ~mat2:enc_outputs |> Tensor.squeeze1 ~dim:1
+ Tensor.bmm attn_weights ~mat2:enc_outputs |> Tensor.squeeze_dim ~dim:1
in
let output =
Tensor.cat [ embedded; attn_applied ] ~dim:1
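A shape walkthrough of the attention application above, with assumed sizes: bmm multiplies [b; 1; len] attention weights by [b; len; hidden] encoder outputs to give [b; 1; hidden], and squeeze_dim ~dim:1 (the renamed squeeze1) drops the singleton axis.

open Torch

let () =
  let b, len, hidden = 2, 10, 256 in
  let attn_weights = Tensor.randn [ b; 1; len ] in
  let enc_outputs = Tensor.randn [ b; len; hidden ] in
  let attn_applied =
    Tensor.bmm attn_weights ~mat2:enc_outputs |> Tensor.squeeze_dim ~dim:1
  in
  assert (Tensor.shape attn_applied = [ b; hidden ])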
69 changes: 49 additions & 20 deletions src/gen/gen.ml
@@ -30,6 +30,10 @@ let excluded_functions =
; "retain_grad"
; "_validate_sparse_coo_tensor_args"
; "count_nonzero"
; "_assert_async"
; "gradient"
; "linalg_vector_norm"
; "linalg_vector_norm_out"
]

let no_tensor_options =
@@ -44,7 +48,7 @@ let no_tensor_options =
; "randn_like"
]

- let excluded_prefixes = [ "thnn_"; "th_"; "_foreach"; "_amp_foreach" ]
+ let excluded_prefixes = [ "thnn_"; "th_"; "_foreach"; "_amp_foreach"; "linalg_norm" ]
let excluded_suffixes = [ "_forward"; "_forward_out" ]
let yaml_error yaml ~msg = failwith [%string "%{msg}, %{Yaml.to_string_exn yaml}"]

@@ -114,6 +118,8 @@ module Func = struct

type t =
{ name : string
+ ; operator_name : string
+ ; overload_name : string
; args : arg list
; returns : (* number of tensors that are returned *)
[ `fixed of int | `dynamic ]
@@ -125,15 +131,14 @@
| "bool" -> Some Bool
| "int64_t" -> Some Int64
| "double" -> Some Double
| "booltensor" | "indextensor" | "tensor" ->
Some (if is_nullable then TensorOption else Tensor)
| "tensoroptions" -> Some TensorOptions
| "intarrayref" | "intlist" -> Some IntList
| "const c10::list<c10::optional<tensor>> &" -> Some TensorOptList
| "tensorlist" -> Some TensorList
| "device" -> Some Device
| "scalar" -> Some Scalar
| "scalartype" -> Some ScalarType
| "at::tensor" -> Some (if is_nullable then TensorOption else Tensor)
| "at::tensoroptions" -> Some TensorOptions
| "at::intarrayref" | "intlist" -> Some IntList
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
| "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
| "at::scalar" | "const at::scalar &" -> Some Scalar
| "at::scalartype" -> Some ScalarType
| "std::string" -> Some String
| _ -> None

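To illustrate the change: libtorch 1.9's Declarations.yaml reports dynamic types with an at:: prefix, so the match arms above were updated accordingly. Below is a self-contained sketch with a hypothetical arg_type_of_string helper restricted to a few arms (the real generator matches on the lowercased dynamic_type string inside Func, using Base rather than the stdlib).

type arg_type = Tensor | TensorOption | TensorOptions | IntList

(* Hypothetical standalone version of the mapping shown above. *)
let arg_type_of_string str ~is_nullable =
  match String.lowercase_ascii str with
  | "at::tensor" -> Some (if is_nullable then TensorOption else Tensor)
  | "at::tensoroptions" -> Some TensorOptions
  | "at::intarrayref" | "intlist" -> Some IntList
  | _ -> None

let () =
  assert (arg_type_of_string "at::Tensor" ~is_nullable:false = Some Tensor);
  assert (arg_type_of_string "at::Tensor" ~is_nullable:true = Some TensorOption)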
@@ -263,6 +268,8 @@ let read_yaml filename =
List.filter_map funcs ~f:(fun yaml ->
let map = extract_map yaml in
let name = Map.find_exn map "name" |> extract_string in
+ let operator_name = Map.find_exn map "operator_name" |> extract_string in
+ let overload_name = Map.find_exn map "overload_name" |> extract_string in
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
let method_of =
Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
@@ -272,9 +279,7 @@
let is_tensor returns =
let returns = extract_map returns in
let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
- String.( = ) return_type "Tensor"
- || String.( = ) return_type "BoolTensor"
- || String.( = ) return_type "IndexTensor"
+ String.( = ) return_type "at::Tensor"
in
let returns = Map.find_exn map "returns" |> extract_list in
if List.for_all returns ~f:is_tensor
@@ -285,7 +290,7 @@
let return_type =
Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
in
- if String.( = ) return_type "TensorList"
+ if String.( = ) return_type "at::TensorList"
|| String.( = )
return_type
"dynamic_type: const c10::List<c10::optional<Tensor>> &"
@@ -336,7 +341,7 @@ let read_yaml filename =
then None
else raise Not_a_simple_arg)
in
- Some { Func.name; args; returns; kind }
+ Some { Func.name; operator_name; overload_name; args; returns; kind }
with
| Not_a_simple_arg -> None)
else None)
@@ -492,7 +497,15 @@ let write_wrapper funcs filename =
pi "end"))

let methods =
- let c name args = { Func.name; args; returns = `fixed 1; kind = `method_ } in
+ let c name args =
+ { Func.name
+ ; operator_name = name
+ ; overload_name = ""
+ ; args
+ ; returns = `fixed 1
+ ; kind = `method_
+ }
+ in
let ca arg_name arg_type = { Func.arg_name; arg_type; default_value = None } in
[ c "grad" [ ca "self" Tensor ]
; c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ]
@@ -506,18 +519,34 @@ let run ~yaml_filename ~cpp_filename ~stubs_filename ~wrapper_filename =
printf "Generating code for %d functions.\n%!" (List.length funcs);
(* Generate some unique names for overloaded functions. *)
let funcs =
- List.map funcs ~f:(fun func -> String.lowercase func.name, func)
+ List.map funcs ~f:(fun func -> String.lowercase func.operator_name, func)
|> Map.of_alist_multi (module String)
|> Map.to_alist
|> List.concat_map ~f:(fun (name, funcs) ->
match funcs with
| [] -> assert false
| [ func ] -> [ name, func ]
| funcs ->
+ let has_empty_overload =
+ List.exists funcs ~f:(fun (func : Func.t) ->
+ String.is_empty func.overload_name)
+ in
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
- Int.compare (List.length f1.args) (List.length f2.args))
- |> List.mapi ~f:(fun i func ->
- (if i = 0 then name else Printf.sprintf "%s%d" name i), func))
+ match Int.compare (String.length f1.name) (String.length f2.name) with
+ | 0 -> Int.compare (List.length f1.args) (List.length f2.args)
+ | cmp -> cmp)
+ |> List.mapi ~f:(fun index (func : Func.t) ->
+ let operator_name = String.lowercase func.operator_name in
+ let overload_name = String.lowercase func.overload_name in
+ let name =
+ if String.is_empty overload_name
+ || (index = 0 && not has_empty_overload)
+ then operator_name
+ else if String.is_suffix operator_name ~suffix:"_"
+ then operator_name ^ overload_name ^ "_"
+ else operator_name ^ "_" ^ overload_name
+ in
+ name, func))
|> Map.of_alist_exn (module String)
in
write_cpp funcs cpp_filename;
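The naming scheme above is the user-visible reason for the sum1/max2-style renames throughout the examples: overloaded functions are now disambiguated by their libtorch overload name rather than by an arbitrary numeric suffix. A self-contained sketch with a hypothetical unique_name helper follows (stdlib strings; the real code sorts Func.t records by name length, then argument count, before indexing):

let unique_name ~operator_name ~overload_name ~index ~has_empty_overload =
  let operator_name = String.lowercase_ascii operator_name in
  let overload_name = String.lowercase_ascii overload_name in
  if overload_name = "" || (index = 0 && not has_empty_overload)
  then operator_name
  else if operator_name.[String.length operator_name - 1] = '_'
  then (* in-place ops keep their trailing underscore at the very end *)
    operator_name ^ overload_name ^ "_"
  else operator_name ^ "_" ^ overload_name

let () =
  (* "sum" has an empty-named overload, so it keeps the bare name, while
     its "dim_IntList" overload becomes the sum_dim_intlist seen above. *)
  assert (unique_name ~operator_name:"sum" ~overload_name:"" ~index:0
            ~has_empty_overload:true = "sum");
  assert (unique_name ~operator_name:"sum" ~overload_name:"dim_IntList" ~index:1
            ~has_empty_overload:true = "sum_dim_intlist")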