티스토리 뷰
preprocess.py
Argparser를 통해서 데이터셋을 처리할 때 편리하게 만들기
import argparse
from utils import PrepareDataset
def audio_process(config) -> None:
print('Start audio processing')
preprocessor = PrepareDataset()
preprocessor.process_audio(
source_dir=config.target_dir,
remove_original_audio=config.remove_original_audio,
) #이미 만들어 놓은 process_audio 함수에 config를 넣는다.
파일을 처리할 명령어를 만든다.
def file_process(config) -> None:
print('Start file processing')
preprocessor = PrepareDataset()
if config.target_file: #만약에 config에 타겟 파일이 있다면
if not (
config.csv or
config.pkl or
config.split_whole_data or
config.split_train_test
):
print(f'If --target-file (-t) is feed, \
one of --csv, --pkl, --split-train-test (-s) or \
--split-whole-data (-w) must be set.') #명령어
return
if config.csv: #csv가 있다면
preprocessor.save_trn_to_csv(config.target_file)
if config.pkl: #pkl이 있다면
preprocessor.save_trn_to_pkl(config.target_file)
if config.split_whole_data: #split_whole_data가 있다면
preprocessor.split_whole_data(config.target_file)
if config.split_train_test: #split_train_test가 있다면
preprocessor.split_train_test(
target_file=config.target_file,
train_size=config.ratio
)
if config.convert_all_to_utf:
if not config.target_dir:
print('If --convert-all-to-utf8 (-c) flagged, you must feed --target-dir')
preprocessor.convert_all_files_to_utf8(config.target_dir)
if config.remove_all_text_files:
if not config.target_dir:
print(f'If --remove-all-text-files (-R) flagged,\
you must feed --target-dir')
return
preprocessor.remove_all_text_files(
target_dir=config.target_dir,
ext=config.remove_file_type
)
각각의 파서를 정의한다.
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog='Korean speech dataset pre-processor',
description='Process Korean speech dataset'
)
sub_parser = parser.add_subparsers(title='sub-command')
### Parser for sub-command 'audio' ###
parser_audio = sub_parser.add_parser(
'audio',
help='sub-command for audio processing'
)
parser_audio.add_argument(
'--target-dir', '-t',
required=True,
help='directory of audio files'
)
parser_audio.add_argument(
'--remove-original-audio', '-r', #없으면 리무브로 저장
action='store_true',
help='Remove original audio files'
)
parser_audio.set_defaults(func=audio_process)
### Parser for sub-command 'file' ###
parser_file = sub_parser.add_parser(
'file',
help='handling txt encoding, generate pkl/csv file,\
or split file (train/test)'
)
parser_file.add_argument(
'--target-file', '-t',
# required=True,
help='Target file name for processing'
)
parser_file.add_argument(
'--convert-all-to-utf', '-c',
action='store_true',
help='Convert all text files to utf-8 under target_dir'
)
parser_file.add_argument(
'--target-dir', '-d',
# required=True,
help='Target directory for converting file encoding to utf-8\
Use by combining --convert-all-to-utf (-c) flag'
)
parser_file.add_argument(
'--split-whole-data', '-w',
action='store_true',
help='Split whole data file int group'
)
parser_file.add_argument(
'--csv',
action='store_true',
help='Generate csv file'
)
parser_file.add_argument(
'--pkl',
action='store_true',
help='Generate pickle file'
)
parser_file.add_argument(
'--split-train-test', '-s',
action='store_true',
help='Flag split train/test set, \
default: 0.8 (train:test -> 80:20)'
)
parser_file.add_argument(
'--ratio',
type=float,
default=0.8,
help='Split file into .train & .test files'
)
parser_file.add_argument(
'--remove-all-text-files', '-R',
action='store_true',
help='Remove all specific type files under target dir'
)
parser_file.add_argument(
'--remove-file-type',
default='txt',
help='Set remove file type'
)
parser_file.set_defaults(func=file_process)
config = parser.parse_args()
return config
config는 그냥 정의해준다.
if __name__=='__main__':
config = get_parser()
config.func(config)
config.func는 parser_file.set_defaults에서의 func를 해당되는 펑션으로 config를 주라는 뜻이 된다.
python preprocess.py audio -h
>options:
-h, help show this help message and exit
--target-dir TARGET_DIR, -t TARGET_DIR
directory of audio files
--remove-original-audio, -r
그러면 다음과 같이 쓸 수 있다.
python preprocess.py audio --target-dir ./data/audio/speech --remove-original-audio
--로 시작하고 -로 이으면 된다.
-r은 단독으로 뒤에 붙는거 없이 써도 된다.
작은 데이터로 샘플링하기
Wav 파일은 너무 커서, Whisper에 넣기 위한 array로 만들려면 컴퓨터의 메모리가 감당하지 못한다.
따라서 구간 별로 나눠서 해당하는 데이터를 전처리 해야 한다.
def split_whole_data(self, target_file:str) -> None:
'''전체 데이터 파일 (train.trn)을 그룹별로 구분
For example, in train.trn file
KsponSpeech_01/KsponSpeech_0001/KsponSpeech_000001.pcm :: 'some text'
-> this file will be stored in train_KsponSpeech_01.trn
KsponSpeech_02/KsponSpeech_0001/KsponSpeech_000002.pcm :: 'some text'
-> this file will be stored in train_KsponSpeech_02.trn
'''
with open(target_file, 'rt') as f:
lines = f.readlines()
data_group = set()
for line in lines:
data_group.add(line.split('/')[0])
data_group = sorted(list(data_group))
data_dic = { group: [] for group in data_group} # dict comprehension
for line in lines:
data_dic[line.split('/')[0]].append(line)
# Save file seperately
# target_file: data/info/train.trn -> ['data', 'info', 'train.trn']
save_dir = target_file.split('/')[:-1]
save_dir = '/'.join(save_dir)
for group, line_list in data_dic.items():
file_path = os.path.join(save_dir, f'train_{group}.trn')
with open(file_path, 'wt', encoding='utf-8') as f:
for text in line_list:
f.write(text)
print(f'File created -> {file_path}')
print('Done!')
해당하는 train.trn 파일을 연다.
해당 파일은 라인마다 데이터가 있으므로, 데이터 그룹에 넣고, 최상의 폴더의 이름을 가져온다..
그리고 해당 데이터 그룹에서 data_dic을 넣기 위해서 comprehension을 만든다.
이 결과로 group이 data_group의 전체가 나오고, 키는 group이 된다.
따라서 line으로 key를 넣어주면 data_dic에 해당되는 라인을 넣어준다.
그러면 해당 data_dic에는 각 그룹에 해당하는 것만 들어간다.
그렇게 해서 save_dir = target_file의 split을 하고, 마지막 파일이름만 빼고 join해서 경로를 정한다.
그리고 나서 key하고 value가 나오면 file_path - os.path.join을 하고,
with open wt해서 f.write 해주면 된다.
10만개씩 나누는 걸 권장한다. 연속해서 파인튜닝하면 똑같은 결과.
좋은 컴퓨터가 없는 사람들은 나눠서 써야 한다.
데이터셋을 .csv 또는 pickle파일로 저장하기
def get_dataset_dict(self, file_name: str, ext: str = 'wav') -> dict:
'''path_dir에 있는 파일을 dict 형태로 가공하여 리턴
return data_dic = {
'audio': ['file_path1', 'file_path2', ...],
'text': ['text1', 'text2', ...]
}'''
data_dic = {'path': [], 'sentence': []}
print(f'file_name: {file_name}')
with open(file_name, 'rt', encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
audio, text = line.split('::')
audio = audio.strip()
audio = os.path.join(
os.getcwd(), # '/home/kafa46/finetune-tutorial'
self.VOICE_DIR.replace('./', ''), # './data/audio' -> 'data/audio'
audio
)
if audio.endswith('.pcm'):
audio = audio.replace('.pcm', f'.{ext}')
text = text.strip()
data_dic['path'].append(audio)
data_dic['sentence'].append(text)
return data_dic
def save_trn_to_pkl(self, file_name: str) -> None:
'''.trn 파일을 dict로 만든 후 .pkl 바이너리로 그냥 저장(dump)'''
data_dict = self.get_dataset_dict(file_name) #이미 dict로 되어 있다면 안해도 된다.
# pickle file dump
file_name_pickle = file_name + '.dic.pkl'
with open(file_name_pickle, 'wb') as f:
pickle.dump(data_dict, f)
print(f'Dataset is saved via dictionary pickle')
print(f'Dataset path: {file_name_pickle}')
def save_trn_to_csv(self, file_name: str) -> None:
'''.trn 파일을 .csv로 저장'''
data_dic = self.get_dataset_dict(file_name)
file_name_csv = file_name.split('.')[:-1]
file_name_csv = ''.join(file_name_csv) + '.csv'
if file_name.startswith('.'):
file_name_csv = '.' + file_name_csv
data_df = pd.DataFrame(data_dic)
# Modified by Giseop Noh since Whisper dataset needs header ㅠㅠ
# modified: header=False -> header=True
data_df.to_csv(file_name_csv, index=False, header=True)
print(f'Dataset is saved via csv')
print(f'Dataset path: {file_name_csv}')
def split_train_test(self,
target_file: str,
train_size: float = 0.8
) -> None:
'''입력 파일(.trn)을 train/test 분류하여 저장
if train_size is 0.8,
train:test = 80%:20%
'''
with open(target_file, 'rt') as f:
data = f.readlines()
train_num = int(len(data) * train_size)
# If you set header (header=True) in csv file, you need following codes
# - Modified by Giseop Noh since Whisper dataset needs header ㅠㅠ
header = None
if target_file.endswith('.csv'):
header = data[0]
data = data[1:]
train_num = int(len(data)*train_size)
shuffle(data)
data_train = sorted(data[0:train_num])
data_test = sorted(data[train_num:])
# train_set 파일 저장
train_file = target_file.split('.')[:-1]
train_file = ''.join(train_file) + '_train.csv'
if target_file.startswith('.'):
train_file = '.' + train_file
with open(train_file, 'wt', encoding='utf-8') as f:
if header:
f.write(header)
for line in data_train:
f.write(line)
print(f'Train_dataset saved -> {train_file} ({train_size*100:.1f}%)')
# test_set 파일 저장
test_file = target_file.split('.')[:-1]
test_file = ''.join(test_file) + '_test.csv'
if target_file.startswith('.'):
test_file = '.' + test_file
with open(test_file, 'wt', encoding='utf-8') as f:
if header:
f.write(header)
for line in data_test:
f.write(line)
print(f'Test_dataset saved -> {test_file} ({(1.0-train_size)*100:.1f}%)')