/**
 * Color-accessibility classifier built with the deeplearn.js Graph API.
 * Tutorial excerpt: `...` marks code that is elided at this step, and
 * `Graph` is assumed to be imported from 'deeplearn' (import not shown here).
 */
class ColorAccessibilityModel {
  // Placeholder for one RGB input sample (declared with shape [3]).
  inputTensor;
  // Placeholder for the expected label (declared with shape [2]).
  targetTensor;

  /**
   * Builds the computation graph: input and target placeholders followed by
   * three stacked fully connected hidden layers (64 -> 32 -> 16 units).
   *
   * @param trainingSet - not referenced yet at this stage of the walkthrough.
   */
  setupSession(trainingSet) {
    const graph = new Graph();
    this.inputTensor = graph.placeholder('input RGB value', [3]);
    this.targetTensor = graph.placeholder('output classifier', [2]);
    // Each hidden layer feeds the next one; the index keeps layer names unique.
    let connectedLayer = this.createConnectedLayer(graph, this.inputTensor, 0, 64);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 1, 32);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 2, 16);
    // NOTE: the last hidden layer is not consumed yet in this excerpt.
  }

  /**
   * Adds one dense (fully connected) layer to `graph`.
   *
   * @param graph - the Graph being built.
   * @param inputLayer - tensor that feeds this layer.
   * @param layerIndex - used to build the layer name `fully_connected_<index>`.
   * @param units - number of units passed to `graph.layers.dense`.
   * @param activationFunction - optional activation; when falsy, ReLU
   *   (`graph.relu`) is used instead.
   * @returns whatever `graph.layers.dense` returns; callers chain it as the
   *   next layer's input.
   */
  createConnectedLayer(
    graph,
    inputLayer,
    layerIndex,
    units,
    activationFunction
  ) {
    return graph.layers.dense(
      `fully_connected_${layerIndex}`,
      inputLayer,
      units,
      activationFunction ? activationFunction : (x) => graph.relu(x)
    );
  }
  ...
}
export default ColorAccessibilityModel;
第六步,創建輸出二分類的層。它有兩個輸出單元,每個單元對應一個離散值(黑色、白色)。
/**
 * Color-accessibility classifier (tutorial excerpt; `...` marks elided code).
 * This step attaches the two-unit output layer.
 */
class ColorAccessibilityModel {
  // Placeholder for one RGB input sample (declared with shape [3]).
  inputTensor;
  // Placeholder for the expected label (declared with shape [2]).
  targetTensor;
  // Output of the final two-unit layer.
  predictionTensor;

  /**
   * Builds the graph: placeholders, three hidden layers (64 -> 32 -> 16
   * units) and a two-unit output layer stored in `predictionTensor`.
   *
   * @param trainingSet - not referenced yet at this stage of the walkthrough.
   */
  setupSession(trainingSet) {
    const graph = new Graph();
    this.inputTensor = graph.placeholder('input RGB value', [3]);
    this.targetTensor = graph.placeholder('output classifier', [2]);
    let connectedLayer = this.createConnectedLayer(graph, this.inputTensor, 0, 64);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 1, 32);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 2, 16);
    // Two output units, one per class; no activation argument is passed, so
    // createConnectedLayer's default activation applies.
    this.predictionTensor = this.createConnectedLayer(graph, connectedLayer, 3, 2);
  }
  ...
}
export default ColorAccessibilityModel;
第七步,聲明一個代價張量(cost tensor),以定義損失函數。在這個案例中,代價張量採用均方誤差(mean squared error)。它使用訓練集的目標張量(標簽)和訓練算法得到的預測張量來計算代價。
/**
 * Color-accessibility classifier (tutorial excerpt; `...` marks elided code).
 * This step adds the cost (loss) tensor.
 */
class ColorAccessibilityModel {
  // Placeholder for one RGB input sample (declared with shape [3]).
  inputTensor;
  // Placeholder for the expected label (declared with shape [2]).
  targetTensor;
  // Output of the final two-unit layer.
  predictionTensor;
  // Loss comparing targetTensor with predictionTensor.
  costTensor;

  /**
   * Builds the graph: placeholders, three hidden layers, the two-unit output
   * layer, and a mean-squared cost between target and prediction.
   *
   * @param trainingSet - not referenced yet at this stage of the walkthrough.
   */
  setupSession(trainingSet) {
    const graph = new Graph();
    this.inputTensor = graph.placeholder('input RGB value', [3]);
    this.targetTensor = graph.placeholder('output classifier', [2]);
    let connectedLayer = this.createConnectedLayer(graph, this.inputTensor, 0, 64);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 1, 32);
    connectedLayer = this.createConnectedLayer(graph, connectedLayer, 2, 16);
    this.predictionTensor = this.createConnectedLayer(graph, connectedLayer, 3, 2);
    // Loss to be minimised during training: mean squared error of the
    // prediction against the target labels.
    this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);
  }
  ...
}
export default ColorAccessibilityModel;
最後但並非不重要的一步:為架構圖設置相應的會話(session)。之後,你就可以開始為訓練階段準備並導入訓練集了。
import {
Graph,
Session,
NDArrayMathGPU,
} from 'deeplearn';
class ColorAccessibilityModel {
session;
inputTensor;
targetTensor;
predictionTensor;
costTensor;
setupSession(trainingSet) {
const graph = new Graph();
this.inputTensor = graph.placeholder('input RGB value', [3]);
this.targetTensor = graph.placeholder('output classifier', [2]);
let connectedLayer = this.createConnectedLayer(graph, this.inputTensor, 0, 64);
connectedLayer = this.createConnectedLayer(graph, connectedLayer, 1, 32);
connectedLayer = this.createConnectedLayer(graph, connectedLayer, 2, 16);
this.predictionTensor = this.createConnectedLayer(graph, connectedLayer, 3, 2);
this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);
this.session = new Session(graph, math);
this.prepareTrainingSet(trainingSet);
}
prepareTrainingSet(trainingSet) {
...
}
...
}
export default ColorAccessibilityModel;
不過,設置還沒有完全結束:接下來還需要為神經網絡準備訓練集。
首先,你可以在 GPU 數學計算上下文中使用回調函數(callback function)來支持計算,即下面代碼中傳給 math.scope 的回調。但這並不是強制性的,不使用它也可以完成計算。
import {
Graph,
Session,
NDArrayMathGPU,
} from 'deeplearn';
// GPU math backend shared by the model (also passed to the Session elsewhere).
const math = new NDArrayMathGPU();

/**
 * Color-accessibility classifier (tutorial excerpt; `...` marks elided code).
 * This step wraps training-set preparation in `math.scope`.
 */
class ColorAccessibilityModel {
  // Session executing the graph; created during setup (elided here).
  session;
  inputTensor;
  targetTensor;
  predictionTensor;
  costTensor;
  ...

  /**
   * Prepares the training set inside a `math.scope` callback.
   * NOTE(review): `scope` presumably lets deeplearn.js track and dispose
   * intermediate NDArrays created inside the callback; confirm against the
   * deeplearn.js NDArrayMath docs. The tutorial text calls it optional.
   *
   * @param trainingSet - the raw training data passed in from setup.
   */
  prepareTrainingSet(trainingSet) {
    math.scope(() => {
      ...
    });
  }
  ...
}
export default ColorAccessibilityModel;
其次,你可以解構訓練集的輸入和輸出(標簽,也稱為目標),將其轉換成神經網絡可讀的格式。deeplearn.js 的數學計算使用其內置的 NDArray。你可以把它們簡單地理解為由數組嵌套而成的矩陣,或者向量。此外,輸入數組中的顏色值會被歸一化,以提高神經網絡的性能。
評論