Check in a token-level language model that runs via the TFLite FFI bindings

Local test run after running `gclient sync`:

ariaye@ariaye1:~/sdk/sdk$ dart pkg/analysis_server/test/services/completion/dart/language_model_test.dart
00:00 +0: calculates lookback
INFO: Initialized TensorFlow Lite runtime.
00:00 +1: predict with defaults
00:01 +2: predict with confidence scores
00:03 +3: predict when no previous tokens
00:04 +4: All tests passed!


Change-Id: I4181bea09cf8fec74d03bba4f83cd26dac818f30
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/109662
Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
This commit is contained in:
Ari Aye 2019-07-25 18:39:51 +00:00
parent b9ab8efca6
commit 5d5a0c0164
6 changed files with 204 additions and 0 deletions

View file

@ -100,6 +100,7 @@ test_process:third_party/pkg/test_process/lib
test_reflective_loader:third_party/pkg/test_reflective_loader/lib
test_runner:pkg/test_runner/lib
testing:pkg/testing/lib
tflite_native:third_party/pkg/tflite_native/lib
typed_data:third_party/pkg/typed_data/lib
unittest:third_party/pkg/unittest/lib
usage:third_party/pkg/usage/lib

13
DEPS
View file

@ -135,6 +135,7 @@ vars = {
"term_glyph_tag": "1.0.1",
"test_reflective_loader_tag": "0.1.8",
"test_tag": "test-v1.6.4",
"tflite_native_rev": "712b8a93fbb4caf83ffed37f154da88c2a517a91",
"typed_data_tag": "1.1.6",
"unittest_rev": "2b8375bc98bb9dc81c539c91aaea6adce12e1072",
"usage_tag": "3.4.0",
@ -360,6 +361,8 @@ deps = {
Var("dart_git") + "term_glyph.git" + "@" + Var("term_glyph_tag"),
Var("dart_root") + "/third_party/pkg/test":
Var("dart_git") + "test.git" + "@" + Var("test_tag"),
Var("dart_root") + "/third_party/pkg/tflite_native":
Var("dart_git") + "tflite_native.git" + "@" + Var("tflite_native_rev"),
Var("dart_root") + "/third_party/pkg/test_descriptor":
Var("dart_git") + "test_descriptor.git" + "@" + Var("test_descriptor_tag"),
Var("dart_root") + "/third_party/pkg/test_process":
@ -399,6 +402,16 @@ deps = {
"dep_type": "cipd",
},
Var("dart_root") + "/pkg/analysis_server/language_model": {
"packages": [
{
"package": "dart/language_model",
"version": "KB68QHR1SKtopACaf3TFcu9MusRbwWqs0L1m_urGLL4C",
}
],
"dep_type": "cipd",
},
Var("dart_root") + "/buildtools": {
"packages": [
{

View file

@ -0,0 +1,110 @@
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
import 'dart:io';
import 'dart:convert';
import 'dart:typed_data';
import 'package:path/path.dart' as path;
import 'package:quiver/check.dart';
import 'package:tflite_native/tflite.dart' as tfl;
/// Interface to TensorFlow-based Dart language model for next-token prediction.
///
/// Instances are created with [LanguageModel.load] and wrap a TensorFlow Lite
/// interpreter, which should be released with [close] once the model is no
/// longer needed.
class LanguageModel {
  static const _defaultCompletions = 100;

  /// Token substituted for any input token that is missing from [_word2idx].
  static const _unknownToken = '<unknown>';

  final tfl.Interpreter _interpreter;
  final Map<String, int> _word2idx;
  final Map<int, String> _idx2word;
  final int _lookback;

  LanguageModel._(
      this._interpreter, this._word2idx, this._idx2word, this._lookback);

  /// Number of previous tokens to look at during predictions.
  int get lookback => _lookback;

  /// Number of completion results to return during predictions.
  int get completions => _defaultCompletions;

  /// Load model from directory.
  ///
  /// Reads `model.tflite`, `word2idx.json` and `idx2word.json` from
  /// [directory]. The lookback size is taken from the last dimension of the
  /// model's single input tensor, whose shape must be `[1, X]`.
  ///
  /// If anything fails after the interpreter has been created, the interpreter
  /// is deleted before the error is rethrown so that it is not leaked.
  factory LanguageModel.load(String directory) {
    // Load model.
    final interpreter =
        tfl.Interpreter.fromFile(path.join(directory, 'model.tflite'));
    try {
      interpreter.allocateTensors();

      // Load word2idx mapping for input.
      final word2idxFile = File(path.join(directory, 'word2idx.json'));
      final word2idx =
          json.decode(word2idxFile.readAsStringSync()).cast<String, int>();

      // Load idx2word mapping for output. JSON object keys are strings, so
      // they are parsed back into integer indexes here.
      final idx2wordFile = File(path.join(directory, 'idx2word.json'));
      final idx2word = json
          .decode(idx2wordFile.readAsStringSync())
          .map<int, String>((k, v) => MapEntry<int, String>(int.parse(k), v));

      // Get lookback size from model input tensor shape.
      final tensorShape = interpreter.getInputTensors().single.shape;
      checkArgument(tensorShape.length == 2 && tensorShape.first == 1,
          message:
              'tensor shape $tensorShape does not match the expected [1, X]');
      final lookback = tensorShape.last;

      return LanguageModel._(interpreter, word2idx, idx2word, lookback);
    } catch (_) {
      // No LanguageModel was constructed, so nobody else can call close();
      // release the interpreter here before propagating the failure.
      interpreter.delete();
      rethrow;
    }
  }

  /// Tear down the interpreter.
  void close() {
    _interpreter.delete();
  }

  /// Predicts the next token to follow a list of precedent tokens
  ///
  /// Returns a list of tokens, sorted by most probable first.
  List<String> predict(Iterable<String> tokens) =>
      predictWithScores(tokens).keys.toList();

  /// Predicts the next token with confidence scores.
  ///
  /// Returns an ordered map of tokens to scores, sorted by most probable first.
  Map<String, double> predictWithScores(Iterable<String> tokens) {
    final tensorIn = _interpreter.getInputTensors().single;
    tensorIn.data = _transformInput(tokens);
    _interpreter.invoke();
    final tensorOut = _interpreter.getOutputTensors().single;
    return _transformOutput(tensorOut.data);
  }

  /// Transforms tokens to data bytes that can be used as interpreter input.
  ///
  /// Only the last [lookback] entries of [tokens] are used. When fewer than
  /// [lookback] tokens are supplied, the remaining slots are left at `0.0`.
  /// The vocabulary is expected to contain an `<unknown>` entry, which stands
  /// in for every token that is not in the vocabulary.
  List<int> _transformInput(Iterable<String> tokens) {
    // The input buffer holds exactly [lookback] values and `setAll` throws
    // when given more, so keep only the most recent tokens.
    final history = tokens.toList();
    final window = history.length > lookback
        ? history.sublist(history.length - lookback)
        : history;

    // Replace out of vocabulary tokens.
    final sanitizedTokens = window
        .map((token) => _word2idx.containsKey(token) ? token : _unknownToken);

    // Get indexes (as floats).
    final indexes = Float32List(lookback)
      ..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));

    // Get bytes
    return Uint8List.view(indexes.buffer);
  }

  /// Transforms interpreter output data to map of tokens to scores.
  ///
  /// [databytes] is reinterpreted as a sequence of 32-bit floats, where the
  /// position of each float is the index looked up in [_idx2word]. At most
  /// [completions] entries are returned, highest score first.
  Map<String, double> _transformOutput(List<int> databytes) {
    // Copy the bytes into a new Uint8List so they can be viewed as floats.
    final bytes = Uint8List.fromList(databytes);

    // Get scores (as floats)
    final probabilities = Float32List.view(bytes.buffer);

    // Get indexes with scores, sorted by scores (descending)
    final entries = probabilities.asMap().entries.toList()
      ..sort((a, b) => b.value.compareTo(a.value));

    // Get tokens with scores, limiting the length. `take` (unlike `sublist`)
    // does not throw when there are fewer than [completions] scores.
    return Map.fromEntries(entries.take(completions))
        .map((k, v) => MapEntry(_idx2word[k], v));
  }
}

View file

@ -19,6 +19,7 @@ dependencies:
source_span: any
package_config: any
path: any
tflite_native: any
watcher: any
yaml: any

View file

@ -0,0 +1,77 @@
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
import 'dart:ffi';
import 'dart:io';
import 'package:analysis_server/src/services/completion/dart/language_model.dart';
import 'package:path/path.dart' as path;
import 'package:test/test.dart';
import 'package:test_reflective_loader/test_reflective_loader.dart';
/// File system path of the language model directory passed to
/// `LanguageModel.load` by the tests below.
///
/// Resolved relative to the running script. [Uri.toFilePath] is used instead
/// of [Uri.path] because the latter returns the percent-encoded URI path
/// (for example `%20` for a space), which is not a valid file system path.
final directory = Platform.script
    .resolve('../../../../language_model/lexeme')
    .toFilePath();

/// Value the 'calculates lookback' test expects the loaded model to report.
const expectedLookback = 100;
/// Registers the language model tests.
///
/// Nothing is registered on Windows or on 32-bit systems. Every test gets a
/// freshly loaded model in `setUp`, and that model is closed in `tearDown`.
void main() {
  if (Platform.isWindows || sizeOf<IntPtr>() == 4) {
    // We don't yet support running tflite on Windows or 32-bit systems.
    return;
  }

  LanguageModel model;

  setUp(() {
    model = LanguageModel.load(directory);
  });

  tearDown(() {
    // Null-aware so that a failed load in setUp is reported as such instead
    // of being followed by a second error from calling close() on null.
    model?.close();
    // Forget the closed model so it can never be closed a second time.
    model = null;
  });

  test('calculates lookback', () {
    expect(model.lookback, expectedLookback);
  });

  test('predict with defaults', () {
    final tokens =
        tokenize('if (list == null) { return; } for (final i = 0; i < list.');
    final suggestions = model.predict(tokens);
    expect(suggestions, hasLength(model.completions));
    expect(suggestions.first, 'length');
  });

  test('predict with confidence scores', () {
    final tokens =
        tokenize('if (list == null) { return; } for (final i = 0; i < list.');
    final suggestions = model.predictWithScores(tokens);
    final best = suggestions.entries.first;
    expect(best.key, 'length');
    expect(best.value, greaterThan(0.8));
  });

  test('predict when no previous tokens', () {
    final tokens = <String>[];
    final suggestions = model.predict(tokens);
    expect(suggestions, hasLength(model.completions));
    expect(suggestions.first, isNotEmpty);
  });

  test('load fail', () {
    const expectedMessage = 'Invalid argument(s): Unable to create model.';
    // Using throwsA rather than try/fail/catch: a bare catch would also catch
    // the failure raised by fail() and hide its message.
    expect(
        () => LanguageModel.load('doesnotexist'),
        throwsA(predicate((e) => e.toString() == expectedMessage,
            'an error whose toString() is "$expectedMessage"')));
  });
}
/// Splits [input] into a list of tokens.
///
/// Breaks are made at every regular-expression word boundary and at every
/// whitespace character. Each resulting piece is trimmed, and pieces that are
/// empty after trimming are left out.
List<String> tokenize(String input) {
  final separator = RegExp(r'\b|\s');
  return input
      .split(separator)
      .map((piece) => piece.trim())
      .where((piece) => piece.isNotEmpty)
      .toList();
}

View file

@ -13,6 +13,7 @@ import 'imported_reference_contributor_test.dart' as imported_ref_test;
import 'inherited_reference_contributor_test.dart' as inherited_ref_test;
import 'keyword_contributor_test.dart' as keyword_test;
import 'label_contributor_test.dart' as label_contributor_test;
import 'language_model_test.dart' as language_model_test;
import 'library_member_contributor_test.dart' as library_member_test;
import 'library_prefix_contributor_test.dart' as library_prefix_test;
import 'local_constructor_contributor_test.dart' as local_constructor_test;
@ -37,6 +38,7 @@ main() {
inherited_ref_test.main();
keyword_test.main();
label_contributor_test.main();
language_model_test.main();
library_member_test.main();
library_prefix_test.main();
local_constructor_test.main();