diff --git a/.gitignore b/.gitignore index 1d4cfa0..8a177e7 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,8 @@ venv.bak/ # mypy .mypy_cache/ + +# pixi environments +.pixi +pixi.lock +*.egg-info diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..267d2d4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,104 @@ +[project] +name = "leanrl" +requires-python = ">=3.8,<3.11" + +[tool.pixi.project] +channels = ["conda-forge","pytorch"] +platforms = ["osx-arm64", "linux-64", "osx-64", "win-64"] + +[tool.pixi.tasks] +dqn = "python leanrl/dqn.py --total-timesteps 256" +#dqn-torchcompile = "python leanrl/dqn_torchcompile.py --total-timesteps 256 --compile --cudagraphs" +dqn-jax = "python leanrl/dqn_jax.py --total-timesteps 256" +ppo = "python leanrl/ppo_continuous_action.py --num-steps 64 --total-timesteps 256" + +[tool.pixi.pypi-dependencies] +gym = "==0.23.1" + +[tool.pixi.dependencies] +tensorboard = ">=2.10.0,<3.0.0" +wandb = ">=0.13.11,<1.0.0" +pytorch = ">=2.0" +stable-baselines3 = "2.0.0" +gymnasium = ">=0.28.1" +moviepy = ">=1.0.3,<2.0.0" +pygame = ">=2.1,<2.2" +huggingface_hub = ">=0.11.1,<1.0.0" +rich = "<12.0" +tenacity = ">=8.2.2,<9.0.0" +tyro = ">=0.5.10,<1.0.0" +pyyaml = ">=6.0.1,<7.0.0" +numpy = ">=1.21.6" + +[tool.pixi.feature.tensordict.dependencies] +tensordict = "*" # latest version not available in conda for osx-arm64 + +[tool.pixi.feature.atari.dependencies] +ale-py = "0.8.1" + +[tool.pixi.feature.atari.pypi-dependencies] +AutoROM = {version = ">=0.4.2,<0.5.0", extras = ["accept-rom-license"]} + +[tool.pixi.feature.opencv.dependencies] +opencv-python = ">=4.6.0.66" + +[tool.pixi.feature.procgen.dependencies] +procgen = ">=0.10.7" + +[tool.pixi.feature.pytest.dependencies] +pytest = ">=7.1.3" + +[tool.pixi.feature.mujoco.dependencies] +mujoco = "<=2.3.3" + +[tool.pixi.feature.imageio.dependencies] +imageio = ">=2.14.1" + +[tool.pixi.feature.docs.dependencies] +mkdocs-material = ">=8.4.3" +markdown-include = ">=0.7.0" + +[tool.pixi.feature.openrlbenchmark.dependencies] +openrlbenchmark = ">=0.1.1b4" + +[tool.pixi.feature.jax.pypi-dependencies] +jax = ">=0.4.11,<0.5" +jaxlib = ">=0.4.7,<0.5" +flax = ">=0.6.8,<0.10" + +[tool.pixi.feature.optuna.dependencies] +optuna = ">=3.0.1" +optuna-dashboard = ">=0.7.2" + +[tool.pixi.feature.envpool.dependencies] +envpool = ">=0.6.4" + +[tool.pixi.feature.multi_agent.dependencies] +PettingZoo = "1.18.1" +SuperSuit = "3.4.0" +multi-agent-ale-py = "0.1.11" + +[tool.pixi.feature.aws.dependencies] +boto3 = ">=1.24.70" +awscli = ">=1.31.0" + +[tool.pixi.feature.shimmy.dependencies] +shimmy = ">=1.1.0" + +[tool.pixi.feature.dm_control.dependencies] +dm-control = ">=1.0.10" + +[tool.pixi.feature.h5py.dependencies] +h5py = ">=3.7.0" + +[tool.pixi.feature.optax.dependencies] +optax = "0.1.4" + +[tool.pixi.feature.chex.dependencies] +chex = "0.1.5" + +[tool.pixi.environments] +dqn = [] +#dqn-torchcompile = ["tensordict"] +dqn-jax = ["jax"] +ppo = ["mujoco"]