Export model.state_dict()

When to use it

During model training, we periodically save checkpoints to disk.

A checkpoint contains the following information:

  • model.state_dict()

  • optimizer.state_dict()

  • and some other information related to training

A checkpoint is needed when we want to resume training from some point. However, if we want to publish the model for inference, only model.state_dict() is needed. In that case, we strip everything except model.state_dict() to reduce the file size of the published model.
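The stripping step can be sketched in a few lines of PyTorch. The snippet below uses a toy model and hypothetical file names (checkpoint.pt, pretrained.pt); a real icefall checkpoint has the same layout with a much larger model.

```python
import torch
import torch.nn as nn

# A toy model and optimizer standing in for a real recipe's.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

# A training checkpoint keeps everything needed to resume training.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "batch_idx_train": 0,
    },
    "checkpoint.pt",
)

# For publishing, keep only the model weights.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
torch.save({"model": checkpoint["model"]}, "pretrained.pt")

stripped = torch.load("pretrained.pt", map_location="cpu")
print(list(stripped.keys()))  # ['model']
```

The published file is smaller because optimizer state (e.g. Adam's per-parameter moments) can be as large as the model itself.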

How to export

Every recipe contains a file export.py that takes one or more checkpoints as input and exports model.state_dict().

Hint

Each export.py contains well-documented usage information.

In the following, we use https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py as an example.

Note

The steps for other recipes are almost the same.

cd egs/librispeech/ASR

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 20 \
  --avg 10

will generate a file pruned_transducer_stateless3/exp/pretrained.pt, which is a dict containing {"model": model.state_dict()} saved by torch.save().
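The exported file can be loaded back with torch.load(). The sketch below uses nn.Linear as a stand-in for the real transducer model that the recipe builds, and a hypothetical file name pretrained_demo.pt:

```python
import torch
import torch.nn as nn

# Create a file with the same layout as pretrained.pt:
# a dict {"model": model.state_dict()}.
model = nn.Linear(4, 2)
torch.save({"model": model.state_dict()}, "pretrained_demo.pt")

# Loading mirrors what pretrained.py does with --checkpoint:
# read the dict, then restore the weights into a freshly built model.
loaded = torch.load("pretrained_demo.pt", map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(loaded["model"])
restored.eval()  # inference mode
```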

How to use the exported model

For each recipe, we provide pretrained models hosted on Hugging Face. You can find links to them in each dataset's RESULTS.md.

In the following, we demonstrate how to use the pretrained model from https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13.

cd egs/librispeech/ASR

git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13

After cloning the repo with git lfs, you will find several files whose names start with pretrained- in the folder icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp. These files contain the model.state_dict() exported by export.py as described above.

Each recipe also contains a file pretrained.py, which can use a pretrained-xxx.pt file to decode sound files. The following is an example:

cd egs/librispeech/ASR

./pruned_transducer_stateless3/pretrained.py \
   --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
   --tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav

The above commands show how to use the exported model with pretrained.py to decode multiple sound files. Its output is given as follows for reference:

2022-10-13 19:09:02,233 INFO [pretrained.py:265] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'encoder_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.21', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4810e00d8738f1a21278b0156a42ff396a2d40ac', 'k2-git-date': 'Fri Oct 7 19:35:03 2022', 'lhotse-version': '1.3.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'onnx-doc-1013', 'icefall-git-sha1': 'c39cba5-dirty', 'icefall-git-date': 'Thu Oct 13 15:17:20 2022', 'icefall-path': '/k2-dev/fangjun/open-source/icefall-master', 'k2-path': '/k2-dev/fangjun/open-source/k2-master/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-jsonl/lhotse/__init__.py', 'hostname': 'de-74279-k2-test-4-0324160024-65bfd8b584-jjlbn', 'IP address': '10.177.74.203'}, 'checkpoint': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt', 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model', 'method': 'greedy_search', 'sound_files': ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav'], 'sample_rate': 16000, 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_sym_per_frame': 1, 'simulate_streaming': False, 'decode_chunk_size': 16, 'left_context': 64, 
'dynamic_chunk_training': False, 'causal_convolution': False, 'short_chunk_size': 25, 'num_left_chunks': 4, 'blank_id': 0, 'unk_id': 2, 'vocab_size': 500}
2022-10-13 19:09:02,233 INFO [pretrained.py:271] device: cpu
2022-10-13 19:09:02,233 INFO [pretrained.py:273] Creating model
2022-10-13 19:09:02,612 INFO [train.py:458] Disable giga
2022-10-13 19:09:02,623 INFO [pretrained.py:277] Number of model parameters: 78648040
2022-10-13 19:09:02,951 INFO [pretrained.py:285] Constructing Fbank computer
2022-10-13 19:09:02,952 INFO [pretrained.py:295] Reading sound files: ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav']
2022-10-13 19:09:02,957 INFO [pretrained.py:301] Decoding started
2022-10-13 19:09:06,700 INFO [pretrained.py:329] Using greedy_search
2022-10-13 19:09:06,912 INFO [pretrained.py:388]
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS

./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN

./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION


2022-10-13 19:09:06,912 INFO [pretrained.py:390] Decoding Done

Use the exported model to run decode.py

When we publish a model, we record its WERs on test datasets in RESULTS.md. This section describes how to use the pretrained model to reproduce those WERs.

cd egs/librispeech/ASR
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13

cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
cd ../..

We create a symlink named epoch-9999.pt pointing to pretrained-iter-1224000-avg-14.pt, so that we can pass --epoch 9999 --avg 1 to decode.py in the following command:

./pruned_transducer_stateless3/decode.py \
    --epoch 9999 \
    --avg 1 \
    --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp \
    --lang-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500 \
    --max-duration 600 \
    --decoding-method greedy_search
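Why the symlink trick works can be sketched as follows. This is a hypothetical simplification of how decode.py maps --epoch and --avg to file names (the real logic lives in icefall's checkpoint utilities): it reads and averages epochs epoch - avg + 1 through epoch.

```python
# Hypothetical sketch: which checkpoint files decode.py reads
# for given --epoch and --avg values.
def checkpoints_to_load(exp_dir: str, epoch: int, avg: int) -> list:
    return [f"{exp_dir}/epoch-{e}.pt" for e in range(epoch - avg + 1, epoch + 1)]

# With --epoch 9999 --avg 1, only epoch-9999.pt is read,
# which is exactly the symlink we just created.
print(checkpoints_to_load("exp", 9999, 1))  # ['exp/epoch-9999.pt']

# For comparison, --epoch 20 --avg 10 would read epochs 11 through 20.
print(checkpoints_to_load("exp", 20, 10))
```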

You will find the decoding results in ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/greedy_search.

Caution

For some recipes, you also need to pass --use-averaged-model False to decode.py, because the exported pretrained model is already averaged.
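Conceptually, checkpoint averaging (what --avg does) means averaging the parameters of several state_dicts element-wise. The sketch below illustrates the idea only; icefall's actual averaging code handles more details.

```python
import torch
import torch.nn as nn

# Element-wise average of several state_dicts (conceptual sketch).
def average_state_dicts(state_dicts):
    avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg

# Two toy "checkpoints" standing in for checkpoints from adjacent epochs.
a = nn.Linear(2, 2)
b = nn.Linear(2, 2)
averaged = average_state_dicts([a.state_dict(), b.state_dict()])
# Since pretrained-*.pt already stores such an average, decode.py must
# not average again, hence --use-averaged-model False for those recipes.
```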

Hint

We assume that you have already run ./prepare.sh to prepare the test dataset before running decode.py.