TensorFlow sample code
| Revision | 6d18270d4073ebe203d054845675b35235aa23e8 (tree) |
|---|---|
| Date | 2018-01-18 20:36:56 |
| Author | hylom <hylom@hylo...> |
| Committer | hylom |
add TFRecord samples
@@ -0,0 +1,104 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os

import tensorflow as tf

def main():
    # Parse command-line arguments
    p = argparse.ArgumentParser(description='convert images to TFRecord format')
    p.add_argument('dimension',
                   type=lambda x: [int(r) for r in x.split("x")],
                   help='image dimension. example: 10x10')
    p.add_argument('label',
                   type=int,
                   help='label.')
    p.add_argument('output',
                   #type=argparse.FileType('w'),
                   help='output file')
    p.add_argument('target_dir',
                   nargs='+',
                   help='target directory')
    args = p.parse_args()
    if len(args.dimension) != 2:
        raise argparse.ArgumentTypeError("dimension must be <num>x<num>")

    # Build the computation graph
    (width, height) = args.dimension
    (filepath, padded_image) = _create_comp_graph(width, height)

    # Prepare the output file
    writer = tf.python_io.TFRecordWriter(args.output)

    # Create a session
    sess = tf.Session()

    # Enumerate the files in the specified directories
    for dirname in args.target_dir:
        for filename in os.listdir(dirname):
            pathname = os.path.join(dirname, filename)
            print("process {}...".format(pathname))
            try:
                image = sess.run(padded_image, {filepath: pathname})
            except tf.errors.InvalidArgumentError:
                # If decoding fails, print a message and continue
                print("{}: invalid jpeg file. ignored.".format(pathname))
                continue

            features = tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[args.label])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                'raw_image': tf.train.Feature(float_list=tf.train.FloatList(value=image.reshape(width*height*3))),
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())

    # Done
    writer.close()
    print("done")

# Read the specified file and resize the image to the specified size
def _create_comp_graph(width, height):
    filepath = tf.placeholder(tf.string, name="pathname")
    content = tf.read_file(filepath)
    raw_image = tf.image.decode_jpeg(content, channels=3)
    # raw_image has dtype uint8

    # Get the image size as [height, width, channels]
    shape = tf.shape(raw_image)
    raw_height = tf.to_float(tf.slice(shape, [0], [1]))
    raw_width = tf.to_float(tf.slice(shape, [1], [1]))

    # Compare the aspect ratio of the source image with that of the
    # output size; if the source's ratio is larger, compute the scale
    # factor so the height matches the output height, otherwise so the
    # width matches the output width.
    # Width/height are integer (int) values, so convert to float as needed.
    aspect_ratio = float(height) / width
    raw_aspect_ratio = raw_height / raw_width
    scale = tf.cond(tf.reduce_any(raw_aspect_ratio > [aspect_ratio]),
                    lambda: [height] / raw_height,
                    lambda: [width] / raw_width
                    )

    # Resize the image by the computed factor;
    # resized_image has dtype float32
    new_size = tf.to_int32(tf.concat([raw_height, raw_width], 0) * scale)
    resized_image = tf.image.resize_images(raw_image, new_size)

    # Pad with zeros to reach the requested size
    padded_image = tf.image.resize_image_with_crop_or_pad(
        resized_image,
        height,
        width
    )

    return (filepath, padded_image)


if __name__ == "__main__":
    main()
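For a quick sanity check, a record written by this converter can be read back with TF 1.x's `tf.python_io.tf_record_iterator`. A minimal sketch, with the path taken from the driver script below (adjust to wherever the `.tfrecord` files were written):

```python
# Minimal sketch: inspect the first record produced by the converter.
import tensorflow as tf

for serialized in tf.python_io.tf_record_iterator("../data2/teach_cat.tfrecord"):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    f = example.features.feature
    print("label: ", f['label'].int64_list.value[0])
    print("width: ", f['width'].int64_list.value[0])
    print("height:", f['height'].int64_list.value[0])
    print("pixels:", len(f['raw_image'].float_list.value))  # 100*100*3 = 30000
    break  # only the first record
```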
@@ -0,0 +1,10 @@
#!/bin/sh
DATA_DIR=../data2

LABEL=0
for i in cat dog monkey; do
  echo convert $i:$LABEL
  ./convert_tfr.py 100x100 $LABEL ${DATA_DIR}/teach_$i.tfrecord ${DATA_DIR}/$i/teach
  ./convert_tfr.py 100x100 $LABEL ${DATA_DIR}/test_$i.tfrecord ${DATA_DIR}/$i/test
  LABEL=$(expr $LABEL + 1)
done
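Judging from the arguments passed to the converter, this driver assumes the source JPEGs were sorted beforehand into per-class `teach`/`test` subdirectories (file names within them are arbitrary), and it writes the `.tfrecord` files directly under `../data2`:

    ../data2/
      cat/teach/     cat/test/
      dog/teach/     dog/test/
      monkey/teach/  monkey/test/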
@@ -0,0 +1,181 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf

INPUT_WIDTH = 100
INPUT_HEIGHT = 100
INPUT_CHANNELS = 3

INPUT_SIZE = INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS
W1_SIZE = 200
OUTPUT_SIZE = 3
LABEL_SIZE = OUTPUT_SIZE

TEACH_FILES = ["../data2/teach_cat.tfrecord",
               "../data2/teach_dog.tfrecord",
               "../data2/teach_monkey.tfrecord"]
TEST_FILES = ["../data2/test_cat.tfrecord",
              "../data2/test_dog.tfrecord",
              "../data2/test_monkey.tfrecord"]

MODEL_NAME = "./neural_model"

tf.set_random_seed(1111)

# Define the model
with tf.variable_scope('model') as scope:
    x1 = tf.placeholder(dtype=tf.float32)
    y = tf.placeholder(dtype=tf.float32)

    # Second layer (hidden layer)
    W1 = tf.get_variable("W1",
                         shape=[INPUT_SIZE, W1_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    b1 = tf.get_variable("b1",
                         shape=[W1_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    x2 = tf.sigmoid(tf.matmul(x1, W1) + b1, name="x2")

    # Third layer (output layer)
    W2 = tf.get_variable("W2",
                         shape=[W1_SIZE, OUTPUT_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    b2 = tf.get_variable("b2",
                         shape=[OUTPUT_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    x3 = tf.nn.softmax(tf.matmul(x2, W2) + b2, name="x3")

    # Cost function
    cross_entropy = -tf.reduce_sum(y * tf.log(x3), name="cross_entropy")
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Accuracy: check whether the index of the largest value in the
    # output tensor matches the correct label
    correct = tf.equal(tf.argmax(x3, 1), tf.argmax(y, 1), name="correct")
    accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy")
    tf.summary.scalar('accuracy', accuracy)

    # Define the optimization algorithm
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(1e-4, name="optimizer")
    minimize = optimizer.minimize(cross_entropy, global_step=global_step, name="minimize")

    # Prepare an object for saving the training results
    saver = tf.train.Saver()

# Function to transform the loaded records
def map_dataset(serialized):
    features = {
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'raw_image': tf.FixedLenFeature([INPUT_SIZE], tf.float32),
    }
    parsed = tf.parse_single_example(serialized, features)

    # Convert the label to a one-hot vector by slicing the
    # corresponding row out of an identity matrix
    raw_label = tf.cast(parsed['label'], tf.int32)
    label = tf.reshape(tf.slice(tf.eye(LABEL_SIZE),
                                [raw_label, 0],
                                [1, LABEL_SIZE]),
                       [LABEL_SIZE])

    image = parsed['raw_image']
    return (image, label, raw_label)

## Load the training dataset
# 200 records per class x 3 classes = 600 records in total
dataset = tf.data.TFRecordDataset(TEACH_FILES)\
            .map(map_dataset)\
            .batch(600)

# Create an iterator to access the data
iterator = dataset.make_one_shot_iterator()
item = iterator.get_next()

# Create a session
sess = tf.Session()

# Initialize the variables
sess.run(tf.global_variables_initializer())

# Check whether a file with saved training results exists,
# and restore it if it does
latest_filename = tf.train.latest_checkpoint("./")
if latest_filename:
    print("load saved model {}".format(latest_filename))
    saver.restore(sess, latest_filename)

# Ops for collecting summaries
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter('data', graph=sess.graph)

# Read the training data
(dataset_x, dataset_y, values_y) = sess.run(item)


steps = tf.train.global_step(sess, global_step)

if steps == 0:
    # Record the initial state
    xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op], {x1: dataset_x, y: dataset_y})
    print("CROSS ENTROPY({}): {}".format(0, xe))
    print("      ACCURACY({}): {}".format(0, acc))
    summary_writer.add_summary(summary, global_step=0)

# Start training
for i in range(10):
    for j in range(100):
        sess.run(minimize, {x1: dataset_x, y: dataset_y})

    # Fetch and record the intermediate state
    xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op], {x1: dataset_x, y: dataset_y})
    print("CROSS ENTROPY({}): {}".format(steps + 100 * (i+1), xe))
    print("      ACCURACY({}): {}".format(steps + 100 * (i+1), acc))
    summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))

# Training finished; save the results
save_path = saver.save(sess, MODEL_NAME, global_step=tf.train.global_step(sess, global_step))
print("Model saved to {}".format(save_path))

## Output the results

# Compute the accuracy when the data used for
# training is fed back into the model
print("----result with teaching data----")

print("assumed label:")
print(sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x}))
print("real label:")
print(sess.run(tf.argmax(y, 1), feed_dict={y: dataset_y}))
print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y: dataset_y}))


# Compute the accuracy when the test data
# is fed into the model
print("----result with test data----")

## Load the test dataset
# 50 records per class x 3 classes = 150 records
dataset2 = tf.data.TFRecordDataset(TEST_FILES)\
             .map(map_dataset)\
             .batch(150)
iterator2 = dataset2.make_one_shot_iterator()
item2 = iterator2.get_next()
(dataset_x, dataset_y, values_y) = sess.run(item2)

# Print the accuracy
print("assumed label:")
print(sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x}))
print("real label:")
print(sess.run(tf.argmax(y, 1), feed_dict={y: dataset_y}))
print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y: dataset_y}))
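Once a checkpoint has been written, the saved weights can be restored into the same graph and run on new inputs. A minimal sketch, assuming it is appended to the training script above so that `x1`, `x3`, `saver`, and `INPUT_SIZE` are in scope; the zero-filled array is just a stand-in for one real flattened 100x100 RGB image:

```python
import numpy as np

# Restore the most recent checkpoint (e.g. ./neural_model-1000)
# into a fresh session and classify a single input vector.
sess2 = tf.Session()
ckpt = tf.train.latest_checkpoint("./")
saver.restore(sess2, ckpt)

# Stand-in for one flattened 100x100 RGB image; real inputs would be
# the float pixel values produced by the converter.
image = np.zeros((1, INPUT_SIZE), dtype=np.float32)
label = sess2.run(tf.argmax(x3, 1), feed_dict={x1: image})
print("predicted label:", label[0])  # 0=cat, 1=dog, 2=monkey per the driver script
```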