{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/jarredou/Music-Source-Separation-Training-Colab-Inference/blob/main/Music_Source_Separation_Training_(Colab_Inference)_CustomModel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "dC6MxtLlx7vN" }, "source": [ "# Colab inference for ZFTurbo's [Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/)\n", "\n", "\n", "<font size=1>*made by [jarredou](https://github.com/jarredou) & deton</font> \n", "[](https://ko-fi.com/Q5Q811R5YI) " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "vKOCPJkyw9yh" }, "outputs": [], "source": [ "#@markdown #Gdrive connection\n", "from google.colab import drive\n", "drive.mount('/content/drive')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ScA4L7gmQEjM" }, "outputs": [], "source": [ "import base64\n", "#@markdown # Install\n", "\n", "%cd /content\n", "!git clone -b colab-inference https://github.com/jarredou/Music-Source-Separation-Training\n", "\n", "#generate new requirements.txt file for faster colab install\n", "req = 'IyB0b3JjaCAjPT0yLjAuMQ0KbnVtcHkNCnBhbmRhcw0Kc2NpcHkNCnNvdW5kZmlsZQ0KbWxfY29sbGVjdGlvbnMNCnRxZG0NCnNlZ21lbnRhdGlvbl9tb2RlbHNfcHl0b3JjaD09MC4zLjMNCnRpbW09PTAuOS4yDQphdWRpb21lbnRhdGlvbnM9PTAuMjQuMA0KcGVkYWxib2FyZD09MC44LjENCm9tZWdhY29uZj09Mi4yLjMNCmJlYXJ0eXBlPT0wLjE0LjENCnJvdGFyeV9lbWJlZGRpbmdfdG9yY2g9PTAuMy41DQplaW5vcHM9PTAuNi4xDQpsaWJyb3NhDQpkZW11Y3MgIz09NC4wLjANCiMgdHJhbnNmb3JtZXJzPT00LjM1LjANCnRvcmNobWV0cmljcz09MC4xMS40DQpzcGFmZT09MC4zLjINCnByb3RvYnVmPT0zLjIwLjMNCnRvcmNoX2F1ZGlvbWVudGF0aW9ucw0KYXN0ZXJvaWQ9PTAuNy4wDQphdXJhbG9zcw0KdG9yY2hzZWcNCg=='\n", "dec_req = base64.b64decode(req).decode('utf-8')\n", "f = open(\"Music-Source-Separation-Training/requirements.txt\", \"w\")\n", "f.write(dec_req)\n", "f.close()\n", "\n", "!mkdir '/content/Music-Source-Separation-Training/ckpts'\n", "\n", "print('Installing the dependencies... 
This will take a few minutes')\n", "!pip install -r 'Music-Source-Separation-Training/requirements.txt' &> /dev/null\n", "\n", "print('Installation is done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GS-QezQ-RG64", "cellView": "form" }, "outputs": [], "source": [ "%cd '/content/Music-Source-Separation-Training/'\n", "import os\n", "import torch\n", "import yaml\n", "from urllib.parse import quote\n", "\n",
"class IndentDumper(yaml.Dumper):\n", "    def increase_indent(self, flow=False, indentless=False):\n", "        return super(IndentDumper, self).increase_indent(flow, False)\n", "\n", "\n",
"def tuple_constructor(loader, node):\n", "    # Load the sequence of values from the YAML node\n", "    values = loader.construct_sequence(node)\n", "    # Return a tuple constructed from the sequence\n", "    return tuple(values)\n", "\n",
"# Register the constructor with PyYAML\n", "yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', tuple_constructor)\n", "\n", "\n",
"def conf_edit(config_path, chunk_size, overlap):\n", "    with open(config_path, 'r') as f:\n", "        data = yaml.load(f, Loader=yaml.SafeLoader)\n", "\n",
"    # handle cases where 'use_amp' is missing from config:\n", "    if 'use_amp' not in data['training'].keys():\n", "        data['training']['use_amp'] = True\n", "\n",
"    data['audio']['chunk_size'] = chunk_size\n", "    data['inference']['num_overlap'] = overlap\n", "\n",
"    if data['inference']['batch_size'] == 1:\n", "        data['inference']['batch_size'] = 2\n", "\n",
"    print(\"Using custom overlap and chunk_size values for roformer model:\")\n", "    print(f\"overlap = {data['inference']['num_overlap']}\")\n", "    print(f\"chunk_size = {data['audio']['chunk_size']}\")\n", "    print(f\"batch_size = {data['inference']['batch_size']}\")\n", "\n",
"    with open(config_path, 'w') as f:\n", "        yaml.dump(data, f, default_flow_style=False, sort_keys=False, Dumper=IndentDumper, allow_unicode=True)\n", "\n", "\n",
"def download_file(url):\n", "    # Encode the URL to handle spaces and special characters\n", "    encoded_url = quote(url, safe=':/')\n", "\n", "    path = 'ckpts'\n", "    os.makedirs(path, exist_ok=True)\n", "    filename = os.path.basename(encoded_url)\n", "    file_path = os.path.join(path, filename)\n", "\n",
"    if os.path.exists(file_path):\n", "        print(f\"File '{filename}' already exists at '{path}'.\")\n", "        return\n", "\n",
"    try:\n", "        torch.hub.download_url_to_file(encoded_url, file_path)\n", "        print(f\"File '{filename}' downloaded successfully\")\n", "    except Exception as e:\n", "        print(f\"Error downloading file '{filename}' from '{url}': {e}\")\n", "\n", "\n",
"#@markdown # Separation\n", "#@markdown #### Model config:\n", "config_url = 'https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml' #@param {type:\"string\"}\n", "ckpt_url = 'https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt' #@param {type:\"string\"}\n", "model_type = 'mel_band_roformer' #@param ['mdx23c','bs_roformer', 'mel_band_roformer', 'bandit', 'bandit_v2', 'scnet', 'apollo', 'htdemucs', 'segm_models', 'torchseg', 'bs_mamba2']\n",
"#@markdown ---\n", "#@markdown #### Separation config:\n", "input_folder = '/content/drive/MyDrive/input' #@param {type:\"string\"}\n", "output_folder = '/content/drive/MyDrive/output' #@param {type:\"string\"}\n", "extract_instrumental = True #@param {type:\"boolean\"}\n", "export_format = 'flac PCM_16' #@param ['wav FLOAT', 'flac PCM_16', 'flac PCM_24']\n", "use_tta = False #@param {type:\"boolean\"}\n",
"#@markdown ---\n", "#@markdown *Roformers custom config:*\n", "overlap = 2 #@param {type:\"slider\", min:2, max:40, step:1}\n", "chunk_size = \"485100\" #@param [352800, 485100] {allow-input: true}\n", "\n",
"if export_format.startswith('flac'):\n", "    flac_file = True\n", "    pcm_type = export_format.split(' ')[1]\n", "else:\n", "    flac_file = False\n", "    pcm_type = None\n", "\n", "\n",
"if config_url != '' and ckpt_url != '':\n", "    config_filename = os.path.basename(config_url)\n", "    ckpt_filename = os.path.basename(ckpt_url)\n", "    print(config_filename, ckpt_filename)\n", "    config_path = f'ckpts/{config_filename}'\n", "    start_check_point = f'ckpts/{ckpt_filename}'\n", "    download_file(config_url)\n", "    download_file(ckpt_url)\n", "    if \"roformer\" in model_type:\n", "        conf_edit(config_path, int(chunk_size), overlap)\n", "\n",
"!python inference.py \\\n", "    --model_type {model_type} \\\n", "    --config_path '{config_path}' \\\n", "    --start_check_point '{start_check_point}' \\\n", "    --input_folder '{input_folder}' \\\n", "    --store_dir '{output_folder}' \\\n", "    {('--extract_instrumental' if extract_instrumental else '')} \\\n", "    {('--flac_file' if flac_file else '')} \\\n", "    {('--use_tta' if use_tta else '')} \\\n", "    {('--pcm_type ' + pcm_type if pcm_type else '')}" ] },
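{ "cell_type": "markdown", "metadata": {}, "source": [ "*Optional sanity check (an added sketch, not part of the original workflow): if the Separation cell above has been run, the next cell reads the downloaded config back and prints the `chunk_size`, `num_overlap` and `batch_size` values written by `conf_edit`, so you can confirm that your custom roformer settings were applied. It assumes the `config_path` variable from the Separation cell is still defined in this session.*" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added sketch: read back the downloaded config and show the values written by conf_edit.\n", "# Assumes the Separation cell above has already run (config_path is defined and the\n", "# python/tuple constructor is registered on yaml.SafeLoader).\n", "import yaml\n", "\n", "try:\n", "    with open(config_path, 'r') as f:\n", "        cfg = yaml.load(f, Loader=yaml.SafeLoader)\n", "    print('chunk_size :', cfg['audio']['chunk_size'])\n", "    print('num_overlap:', cfg['inference']['num_overlap'])\n", "    print('batch_size :', cfg['inference']['batch_size'])\n", "except NameError:\n", "    print('config_path is not defined - run the Separation cell first.')" ] },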
{ "cell_type": "markdown", "source": [ "**INST-Mel-Roformer v1/1x/2** has switched output file names - files labelled as vocals are the instrumentals (if you uncheck extract_instrumental for the v1e model, only one stem called \"other\" will be rendered, and it will be the instrumental).<br><br>\n",
"**TTA** - results in longer separation time; \"it gives a little better SDR score but hard to tell if it's really audible in most cases\". <br> It \"means 'test time augmentation', (...) it will do 3 passes on the audio file instead of 1. 1 pass will be with the original audio. 1 will be with inverted stereo (L becomes R, R becomes L). 1 will be with phase inverted, and then the results are averaged for the final output.\" - jarredou\n", "<br><br>\n",
"**Overlap** - higher values mean longer separation time. 4 is already a balanced value; 2 is fast and some people still won't notice any difference. Normally there's no point going over 8.<br><br>\n",
"If your separation can't start and \"Total files found: 0\" is shown, be aware that: <br>1) Input must be the path to a folder containing audio files, not a direct path to an audio file.<br> 2) Paths are case-sensitive - e.g. if your folder is named \"input\", don't enter \"Input\".<br> 3) Check that your Google Drive was mounted correctly: open the file manager on the left and check that your drive folder is not empty. If it is empty, force a remount with the following line:" ], "metadata": { "id": "fJ-4Ilx2XTKd" } },
{ "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount(\"/content/drive\", force_remount=True)" ], "metadata": { "id": "Tfv8v8jgihdQ" }, "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [], "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }