Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import pickle | |
| from collections import Counter | |
class TorchVocab(object):
    """Vocabulary mapping between tokens and integer ids.

    :property freqs: collections.Counter, word frequencies of the corpus
    :property stoi: dict, string -> id mapping
    :property itos: list, id -> string mapping
    """

    # NOTE: `specials` is a tuple, not a list — a mutable default argument
    # would be shared across all calls and could be mutated by one caller.
    def __init__(self, counter, max_size=None, min_freq=1,
                 specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param counter: collections.Counter with token frequencies of the data
        :param max_size: int, maximum vocabulary size; None means unlimited. Default None.
        :param min_freq: int, minimum frequency; words rarer than this are not added.
        :param specials: iterable of str, tokens pre-registered in the vocabulary
        :param vectors: pretrained vectors, e.g. Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # special tokens' frequencies are excluded when building the vocabulary
        # (Counter.__delitem__ is a no-op for missing keys, so this is safe)
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort alphabetically first, then by frequency (descending);
        # the stable sort keeps alphabetical order within equal frequencies
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # stop adding once words become rarer than min_freq or the vocab is full
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # invert itos (index -> token) to build stoi (token -> index)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Two vocabularies are equal iff freqs, stoi, itos and vectors all match."""
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild stoi from the current itos (use after itos was modified)."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append words of vocabulary `v` that are not yet in this vocabulary.

        :param v: another TorchVocab
        :param sort: if True, append v's words in sorted order
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
class Vocab(TorchVocab):
    """Vocabulary with the special tokens used by the language model."""

    def __init__(self, counter, max_size=None, min_freq=1):
        # fixed ids of the special tokens; they match the order of `specials` below
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    # meant to be overridden by subclasses
    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # meant to be overridden by subclasses
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    # @staticmethod added: the function takes no self, and without the
    # decorator `instance.load_vocab(path)` would pass the instance as vocab_path
    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Load a pickled Vocab from `vocab_path`.

        NOTE(security): pickle.load can execute arbitrary code — only load
        vocabulary files from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to `vocab_path`."""
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)
# build a vocabulary from a text corpus
class WordVocab(Vocab):
    """Vocabulary built from an iterable of texts.

    Each element of `texts` may be either a list of tokens or a raw string,
    which is whitespace-tokenized after stripping newlines and tabs.
    """

    def __init__(self, texts, max_size=None, min_freq=1):
        """
        :param texts: iterable of str or list-of-str
        :param max_size: int, maximum vocabulary size; None means unlimited
        :param min_freq: int, minimum word frequency to be included
        """
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()
            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Convert a sentence into a list of token ids.

        :param sentence: str (whitespace-tokenized) or list of tokens
        :param seq_len: if given, pad or truncate the sequence to this length
        :param with_eos: append <eos> at the end
        :param with_sos: prepend <sos> at the front
        :param with_len: also return the length before padding/truncation
        :return: list of ids, or (ids, original_length) if with_len is True
        """
        if isinstance(sentence, str):
            sentence = sentence.split()

        # unknown words map to <unk>
        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # <eos> id is 2 (see Vocab.eos_index)
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)
        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            # pad up to seq_len with <pad>
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            # truncate down to seq_len
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert a list of ids back into tokens.

        :param seq: iterable of int ids
        :param join: if True, return a single space-joined string
        :param with_pad: if True, keep <pad> tokens; otherwise they are dropped
        """
        # out-of-range ids are rendered as "<id>"
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    # @staticmethod added: the function takes no self, and without the
    # decorator `instance.load_vocab(path)` would pass the instance as vocab_path
    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Load a pickled WordVocab from `vocab_path`.

        NOTE(security): pickle.load can execute arbitrary code — only load
        vocabulary files from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
