blob: bd6ccd3e3a7850c1ede1c67efe0502b7182ec7f1 [file] [log] [blame]
sbarati@apple.com2318ef52017-04-25 00:04:00 +00001/*
2 * Copyright (C) 2017 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
17 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
"use strict";

// currentTime() returns a timestamp in milliseconds from the best clock the
// host offers: performance.now() when available, otherwise a host-provided
// preciseTime() (presumably seconds, hence the * 1000), otherwise Date.now().
// Host features are probed with `typeof` rather than `this.<name>` so the
// probe also works where top-level `this` is undefined (e.g. ES modules);
// in a classic script the result is the same as before.
let currentTime;
if (typeof performance !== "undefined" && typeof performance.now === "function")
    currentTime = function() { return performance.now(); };
else if (typeof preciseTime === "function")
    currentTime = function() { return preciseTime() * 1000; };
else
    currentTime = function() { return Date.now(); };
35
/**
 * Machine-learning workload: for each of several small data sets, trains one
 * feed-forward neural network per available activation function and runs
 * predictions with it.
 *
 * Relies on the globals `MLMatrix`, `FeedforwardNeuralNetworksActivationFunctions`
 * and `FeedforwardNeuralNetwork`, which are not defined in this file and must
 * exist by the time runIteration() is called.
 */
class MLBenchmark {
    /**
     * Runs one full pass of the workload.
     *
     * @throws {Error} "Bad" if the "Big case" network does not score the
     *     second class above the first for the probe point [5, 4].
     */
    runIteration()
    {
        const Matrix = MLMatrix;
        const ACTIVATION_FUNCTIONS = FeedforwardNeuralNetworksActivationFunctions;

        // Mocha-style wrapper kept for readability: `name` is only a label,
        // the body runs immediately.
        const it = (name, f) => {
            f();
        };

        function assert(b) {
            if (!b)
                throw new Error("Bad");
        }

        const functions = Object.keys(ACTIVATION_FUNCTIONS);

        // Builds one fresh network per activation function, trains it on
        // (trainingSet, predictions), then hands it to `evaluate`.
        function trainWithEachActivation(trainingSet, predictions, hiddenLayers, iterations, learningRate, evaluate) {
            for (const activation of functions) {
                const network = new FeedforwardNeuralNetwork({
                    // Copy so every network gets its own array, as a fresh
                    // literal per iteration did before.
                    hiddenLayers: hiddenLayers.slice(),
                    iterations,
                    learningRate,
                    activation
                });
                network.train(trainingSet, predictions);
                evaluate(network);
            }
        }

        it('Training the neural network with XOR operator', function () {
            const trainingSet = new Matrix([[0, 0], [0, 1], [1, 0], [1, 1]]);
            const predictions = [false, true, true, false];

            trainWithEachActivation(trainingSet, predictions, [4], 40, 0.3,
                network => network.predict(trainingSet));
        });

        it('Training the neural network with AND operator', function () {
            const trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
            const predictions = [[1, 0], [1, 0], [1, 0], [0, 1]];

            trainWithEachActivation(trainingSet, predictions, [3], 75, 0.3,
                network => network.predict(trainingSet));
        });

        it('Export and import', function () {
            const trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
            const predictions = [0, 1, 1, 1];

            trainWithEachActivation(trainingSet, predictions, [4], 40, 0.3, network => {
                // Round-trip the trained model through JSON, reload it, and
                // predict with the reloaded copy.
                const model = JSON.parse(JSON.stringify(network));
                FeedforwardNeuralNetwork.load(model).predict(trainingSet);
            });
        });

        it('Multiclass clasification', function () {
            const trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
            const predictions = [2, 0, 1, 0];

            trainWithEachActivation(trainingSet, predictions, [4], 40, 0.5,
                network => network.predict(trainingSet));
        });

        it('Big case', function () {
            const trainingSet = [[1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [1, 3], [1, 4], [4, 1],
                [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [5, 5], [4, 5], [3, 5]];
            const predictions = [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0],
                [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]];

            trainWithEachActivation(trainingSet, predictions, [20], 60, 0.01, network => {
                // [5, 4] lies among the points labelled [0, 1], so the second
                // output should dominate.
                const result = network.predict([[5, 4]]);
                assert(result[0][0] < result[0][1]);
            });
        });
    }
}
157
/**
 * Runs the ML benchmark `numIterations` times and returns the elapsed time,
 * in the units produced by currentTime().
 *
 * @param {number} [numIterations=60] - How many times to call runIteration().
 * @returns {number} Time spent constructing the benchmark and running all iterations.
 */
function runBenchmark(numIterations = 60)
{
    const before = currentTime();

    // Instantiate the benchmark class this file defines. The previous code
    // used `Benchmark`, which is not declared in this file.
    const benchmark = new MLBenchmark();

    for (let iteration = 0; iteration < numIterations; ++iteration)
        benchmark.runIteration();

    const after = currentTime();
    return after - before;
}
171}