{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ESPnet Two-Pass SLU Demonstration\n",
        "\n",
        "This notebook demonstrates the two-pass end-to-end Spoken Language Understanding (SLU) model.\n",
        "\n",
        "Paper Link: https://arxiv.org/abs/2207.06670\n",
        "\n",
        "ESPnet2-SLU: https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/slu1\n",
        "\n",
        "Author: Siddhant Arora"
      ],
      "metadata": {
        "id": "dlicd4tC4xjp"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Qvsfvkm4b5k"
      },
      "outputs": [],
      "source": [
        "# Install everything with %pip so the packages land in the running kernel's environment.\n",
        "%pip install transformers\n",
        "# Clone ESPnet only when /espnet is missing, so re-running this cell does not fail on an existing checkout.\n",
        "!test -d /espnet || git clone https://github.com/espnet/espnet /espnet\n",
        "%pip install /espnet\n",
        "%pip install -q espnet_model_zoo\n",
        "# fairseq is installed from a fixed commit of its git repository.\n",
        "%pip install fairseq@git+https://github.com/pytorch/fairseq.git@f2146bdc7abf293186de9449bfa2272775e39e1d#egg=fairseq"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download Audio File"
      ],
      "metadata": {
        "id": "AwRW0VP5nwwJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# gdown takes the Google Drive file ID as a positional argument; the `--id` flag is deprecated.\n",
        "# Alternative file ID (disabled):\n",
        "# !gdown 1LxoxCoFgx3u8CvKb1loybGFtArKKPcAH -O /content/audio_file.wav\n",
        "!gdown 18ANT62ittt7Ai2E8bQRlvT0ZVXXsf1eE -O /content/audio_file.wav"
      ],
      "metadata": {
        "id": "r1oiDIEon9dv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "import soundfile\n",
        "from IPython.display import Audio, display\n",
        "\n",
        "# Load the downloaded clip together with its sampling rate.\n",
        "mixwav_mc, sr = soundfile.read(\"/content/audio_file.wav\")\n",
        "\n",
        "# Show an inline audio player for the clip.\n",
        "audio_player = Audio(mixwav_mc.T, rate=sr)\n",
        "display(audio_player)"
      ],
      "metadata": {
        "id": "lMy4Vwoc0MPo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download and Load pretrained First Pass Model"
      ],
      "metadata": {
        "id": "x526Ibtjnv0N"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# `git lfs clone` is deprecated: once the LFS filters are installed, a plain `git clone` also fetches the LFS files.\n",
        "!git lfs install\n",
        "!git clone https://huggingface.co/espnet/siddhana_slurp_new_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best /content/slurp_first_pass_model"
      ],
      "metadata": {
        "id": "aQNFXcMuo4iJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from espnet2.bin.asr_inference import Speech2Text\n",
        "\n",
        "# Config and checkpoint come from the repository cloned to /content/slurp_first_pass_model.\n",
        "first_pass_kwargs = dict(\n",
        "    asr_train_config=\"/content/slurp_first_pass_model/exp/asr_train_asr_conformer_raw_en_word/config.yaml\",\n",
        "    asr_model_file=\"/content/slurp_first_pass_model/exp/asr_train_asr_conformer_raw_en_word/valid.acc.ave_10best.pth\",\n",
        "    nbest=1,\n",
        ")\n",
        "speech2text_slurp = Speech2Text.from_pretrained(**first_pass_kwargs)"
      ],
      "metadata": {
        "id": "9Avmes1H1Xgo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def text_normalizer(sub_word_transcript):\n",
        "    \"\"\"Join sub-word tokens into a space-separated transcript.\n",
        "\n",
        "    A token that contains the \"▁\" marker starts a new word; every other\n",
        "    token is glued onto the previous one. The marker itself is removed.\n",
        "    An empty token list yields an empty string instead of raising IndexError.\n",
        "    \"\"\"\n",
        "    transcript = \"\"\n",
        "    for position, sub_word in enumerate(sub_word_transcript):\n",
        "        # The first token never gets a leading space, with or without the marker.\n",
        "        if position > 0 and \"▁\" in sub_word:\n",
        "            transcript += \" \"\n",
        "        transcript += sub_word.replace(\"▁\", \"\")\n",
        "    return transcript\n",
        "\n",
        "\n",
        "nbests_orig = speech2text_slurp(mixwav_mc)\n",
        "text, *_ = nbests_orig[0]\n",
        "\n",
        "# The first whitespace-separated token is parsed as \"<scenario>_<action>\";\n",
        "# the remaining tokens are the sub-word transcript.\n",
        "intent_token, *transcript_tokens = text.split()\n",
        "scenario, _sep, action = intent_token.partition(\"_\")\n",
        "intent_text = \"{scenario: \" + scenario + \", action: \" + action + \"}\"\n",
        "print(f\"INTENT: {intent_text}\")\n",
        "\n",
        "transcript = text_normalizer(transcript_tokens)\n",
        "print(f\"ASR hypothesis: {transcript}\")\n",
        "print(\"First pass SLU model fails to predict the correct action.\")"
      ],
      "metadata": {
        "id": "8Ope3h7P3aKz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download and Load pretrained Second Pass Model"
      ],
      "metadata": {
        "id": "tKB6r3c-398T"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# `git lfs clone` is deprecated: once the LFS filters are installed, a plain `git clone` also fetches the LFS files.\n",
        "!git lfs install\n",
        "!git clone https://huggingface.co/espnet/slurp_slu_2pass /content/slurp_second_pass_model"
      ],
      "metadata": {
        "id": "gFi7KOFE4HzK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from espnet2.bin.slu_inference import Speech2Understand\n",
        "from transformers import AutoModel, AutoTokenizer\n",
        "\n",
        "# Config and checkpoint come from the repository cloned to /content/slurp_second_pass_model.\n",
        "second_pass_kwargs = dict(\n",
        "    slu_train_config=\"/content/slurp_second_pass_model/exp/slu_train_asr_bert_conformer_deliberation_raw_en_word/config.yaml\",\n",
        "    slu_model_file=\"/content/slurp_second_pass_model/exp/slu_train_asr_bert_conformer_deliberation_raw_en_word/valid.acc.ave_10best.pth\",\n",
        "    nbest=1,\n",
        ")\n",
        "speech2text_second_pass_slurp = Speech2Understand.from_pretrained(**second_pass_kwargs)"
      ],
      "metadata": {
        "id": "bt9UqWvj60K7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from espnet2.tasks.slu import SLUTask\n",
        "\n",
        "# Build the preprocessing function from the training arguments stored on the second pass model.\n",
        "# NOTE(review): the False flag presumably selects non-training (inference) preprocessing - confirm against SLUTask.\n",
        "second_pass_train_args = speech2text_second_pass_slurp.asr_train_args\n",
        "preprocess_fn = SLUTask.build_preprocess_fn(second_pass_train_args, False)"
      ],
      "metadata": {
        "id": "L4CvcTWHMg23"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "# Write the cleaned text to a new name: `transcript` stays untouched, so this cell\n",
        "# gives the same result every time it is re-run.\n",
        "transcript_clean = preprocess_fn.text_cleaner(transcript)\n",
        "tokens = preprocess_fn.transcript_tokenizer.text2tokens(transcript_clean)\n",
        "token_ids = preprocess_fn.transcript_token_id_converter.tokens2ids(tokens)\n",
        "text_ints = np.array(token_ids, dtype=np.int64)"
      ],
      "metadata": {
        "id": "eXK8jabBM7KL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Run the second pass model on the audio, passing the token ids of the first pass transcript as its second input.\n",
        "nbests = speech2text_second_pass_slurp(mixwav_mc, torch.tensor(text_ints))\n",
        "text1, *_ = nbests[0]\n",
        "\n",
        "# The first whitespace-separated token is parsed as \"<scenario>_<action>\";\n",
        "# the remaining tokens are the sub-word transcript.\n",
        "second_pass_intent_token, *second_pass_tokens = text1.split()\n",
        "second_pass_scenario, _sep, second_pass_action = second_pass_intent_token.partition(\"_\")\n",
        "\n",
        "# Results go to their own names so the first pass `intent_text` and `transcript`\n",
        "# are not overwritten; earlier cells can then be re-run with unchanged inputs.\n",
        "intent_text_second_pass = \"{scenario: \" + second_pass_scenario + \", action: \" + second_pass_action + \"}\"\n",
        "print(f\"INTENT: {intent_text_second_pass}\")\n",
        "\n",
        "transcript_second_pass = text_normalizer(second_pass_tokens)\n",
        "print(f\"ASR hypothesis: {transcript_second_pass}\")\n",
        "print(\"Second pass SLU model successfully recognizes the correct action.\")"
      ],
      "metadata": {
        "id": "TCELe9RS_62x"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}