Mirror of https://github.com/dart-lang/sdk
Check in token-level language model via tflite ffi
Local test run after running `gclient sync`:

ariaye@ariaye1:~/sdk/sdk$ dart pkg/analysis_server/test/services/completion/dart/language_model_test.dart
00:00 +0: calculates lookback
INFO: Initialized TensorFlow Lite runtime.
00:00 +1: predict with defaults
00:01 +2: predict with confidence scores
00:03 +3: predict when no previous tokens
00:04 +4: All tests passed!

Change-Id: I4181bea09cf8fec74d03bba4f83cd26dac818f30
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/109662
Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
This commit is contained in:
parent b9ab8efca6
commit 5d5a0c0164
6 changed files with 204 additions and 0 deletions
.packages

@@ -100,6 +100,7 @@ test_process:third_party/pkg/test_process/lib
 test_reflective_loader:third_party/pkg/test_reflective_loader/lib
 test_runner:pkg/test_runner/lib
 testing:pkg/testing/lib
+tflite_native:third_party/pkg/tflite_native/lib
 typed_data:third_party/pkg/typed_data/lib
 unittest:third_party/pkg/unittest/lib
 usage:third_party/pkg/usage/lib
DEPS (13 changed lines)
@@ -135,6 +135,7 @@ vars = {
   "term_glyph_tag": "1.0.1",
   "test_reflective_loader_tag": "0.1.8",
   "test_tag": "test-v1.6.4",
+  "tflite_native_rev": "712b8a93fbb4caf83ffed37f154da88c2a517a91",
   "typed_data_tag": "1.1.6",
   "unittest_rev": "2b8375bc98bb9dc81c539c91aaea6adce12e1072",
   "usage_tag": "3.4.0",
@@ -360,6 +361,8 @@ deps = {
       Var("dart_git") + "term_glyph.git" + "@" + Var("term_glyph_tag"),
   Var("dart_root") + "/third_party/pkg/test":
       Var("dart_git") + "test.git" + "@" + Var("test_tag"),
+  Var("dart_root") + "/third_party/pkg/tflite_native":
+      Var("dart_git") + "tflite_native.git" + "@" + Var("tflite_native_rev"),
   Var("dart_root") + "/third_party/pkg/test_descriptor":
       Var("dart_git") + "test_descriptor.git" + "@" + Var("test_descriptor_tag"),
   Var("dart_root") + "/third_party/pkg/test_process":
@@ -399,6 +402,16 @@ deps = {
     "dep_type": "cipd",
   },
+
+  Var("dart_root") + "/pkg/analysis_server/language_model": {
+    "packages": [
+      {
+        "package": "dart/language_model",
+        "version": "KB68QHR1SKtopACaf3TFcu9MusRbwWqs0L1m_urGLL4C",
+      }
+    ],
+    "dep_type": "cipd",
+  },
 
   Var("dart_root") + "/buildtools": {
     "packages": [
       {
pkg/analysis_server/lib/src/services/completion/dart/language_model.dart (new file)

@@ -0,0 +1,110 @@
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

import 'dart:io';
import 'dart:convert';
import 'dart:typed_data';

import 'package:path/path.dart' as path;
import 'package:quiver/check.dart';
import 'package:tflite_native/tflite.dart' as tfl;

/// Interface to TensorFlow-based Dart language model for next-token prediction.
class LanguageModel {
  static const _defaultCompletions = 100;

  final tfl.Interpreter _interpreter;
  final Map<String, int> _word2idx;
  final Map<int, String> _idx2word;
  final int _lookback;

  LanguageModel._(
      this._interpreter, this._word2idx, this._idx2word, this._lookback);

  /// Number of previous tokens to look at during predictions.
  int get lookback => _lookback;

  /// Number of completion results to return during predictions.
  int get completions => _defaultCompletions;

  /// Load model from directory.
  factory LanguageModel.load(String directory) {
    // Load model.
    final interpreter =
        tfl.Interpreter.fromFile(path.join(directory, 'model.tflite'));
    interpreter.allocateTensors();

    // Load word2idx mapping for input.
    final word2idx = json
        .decode(File(path.join(directory, 'word2idx.json')).readAsStringSync())
        .cast<String, int>();

    // Load idx2word mapping for output.
    final idx2word = json
        .decode(File(path.join(directory, 'idx2word.json')).readAsStringSync())
        .map<int, String>((k, v) => MapEntry<int, String>(int.parse(k), v));

    // Get lookback size from model input tensor shape.
    final tensorShape = interpreter.getInputTensors().single.shape;
    checkArgument(tensorShape.length == 2 && tensorShape.first == 1,
        message:
            'tensor shape $tensorShape does not match the expected [1, X]');
    final lookback = tensorShape.last;

    return LanguageModel._(interpreter, word2idx, idx2word, lookback);
  }

  /// Tear down the interpreter.
  void close() {
    _interpreter.delete();
  }

  /// Predicts the next token to follow a list of precedent tokens
  ///
  /// Returns a list of tokens, sorted by most probable first.
  List<String> predict(Iterable<String> tokens) =>
      predictWithScores(tokens).keys.toList();

  /// Predicts the next token with confidence scores.
  ///
  /// Returns an ordered map of tokens to scores, sorted by most probable first.
  Map<String, double> predictWithScores(Iterable<String> tokens) {
    final tensorIn = _interpreter.getInputTensors().single;
    tensorIn.data = _transformInput(tokens);
    _interpreter.invoke();
    final tensorOut = _interpreter.getOutputTensors().single;
    return _transformOutput(tensorOut.data);
  }

  /// Transforms tokens to data bytes that can be used as interpreter input.
  List<int> _transformInput(Iterable<String> tokens) {
    // Replace out of vocabulary tokens.
    final sanitizedTokens = tokens
        .map((token) => _word2idx.containsKey(token) ? token : '<unknown>');

    // Get indexes (as floats).
    final indexes = Float32List(lookback)
      ..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));

    // Get bytes
    return Uint8List.view(indexes.buffer);
  }

  /// Transforms interpreter output data to map of tokens to scores.
  Map<String, double> _transformOutput(List<int> databytes) {
    // Get bytes.
    final bytes = Uint8List.fromList(databytes);

    // Get scores (as floats)
    final probabilities = Float32List.view(bytes.buffer);

    // Get indexes with scores, sorted by scores (descending)
    final entries = probabilities.asMap().entries.toList()
      ..sort((a, b) => b.value.compareTo(a.value));

    // Get tokens with scores, limiting the length.
    return Map.fromEntries(entries.sublist(0, completions))
        .map((k, v) => MapEntry(_idx2word[k], v));
  }
}
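The _transformInput/_transformOutput pair above exchanges data with the interpreter by reinterpreting a Float32List of token indexes as raw bytes and the output bytes as 32-bit floats. A minimal standalone sketch of that round trip (not part of the commit; the values are made up for illustration):

import 'dart:typed_data';

void main() {
  // Token indexes are written as 32-bit floats and viewed as bytes,
  // as _transformInput does before assigning the bytes to the input tensor.
  final indexes = Float32List(4)..setAll(0, [3.0, 7.0, 12.0, 0.0]);
  final inputBytes = Uint8List.view(indexes.buffer);
  print(inputBytes.length); // 16: four floats, four bytes each.

  // Output bytes are copied and reinterpreted as 32-bit floats,
  // as _transformOutput does with the output tensor's data.
  final scores = Float32List.view(Uint8List.fromList(inputBytes).buffer);
  print(scores); // [3.0, 7.0, 12.0, 0.0]
}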
pkg/analysis_server/pubspec.yaml

@@ -19,6 +19,7 @@ dependencies:
   source_span: any
   package_config: any
   path: any
+  tflite_native: any
   watcher: any
   yaml: any
pkg/analysis_server/test/services/completion/dart/language_model_test.dart (new file)

@@ -0,0 +1,77 @@
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

import 'dart:ffi';
import 'dart:io';

import 'package:analysis_server/src/services/completion/dart/language_model.dart';
import 'package:path/path.dart' as path;
import 'package:test/test.dart';
import 'package:test_reflective_loader/test_reflective_loader.dart';

final directory =
    Platform.script.resolve('../../../../language_model/lexeme').path;
const expectedLookback = 100;

void main() {
  if (Platform.isWindows || sizeOf<IntPtr>() == 4) {
    // We don't yet support running tflite on Windows or 32-bit systems.
    return;
  }

  LanguageModel model;

  setUp(() {
    model = LanguageModel.load(directory);
  });

  tearDown(() {
    model.close();
  });

  test('calculates lookback', () {
    expect(model.lookback, expectedLookback);
  });

  test('predict with defaults', () {
    final tokens =
        tokenize('if (list == null) { return; } for (final i = 0; i < list.');
    final suggestions = model.predict(tokens);
    expect(suggestions, hasLength(model.completions));
    expect(suggestions.first, 'length');
  });

  test('predict with confidence scores', () {
    final tokens =
        tokenize('if (list == null) { return; } for (final i = 0; i < list.');
    final suggestions = model.predictWithScores(tokens);
    final best = suggestions.entries.first;
    expect(best.key, 'length');
    expect(best.value, greaterThan(0.8));
  });

  test('predict when no previous tokens', () {
    final tokens = <String>[];
    final suggestions = model.predict(tokens);
    expect(suggestions, hasLength(model.completions));
    expect(suggestions.first, isNotEmpty);
  });

  test('load fail', () {
    try {
      LanguageModel.load('doesnotexist');
      fail('Failure to load language model should throw an exception');
    } catch (e) {
      expect(
          e.toString(), equals('Invalid argument(s): Unable to create model.'));
    }
  });
}

/// Tokenizes the input string.
///
/// The input is split by word boundaries and trimmed of whitespace.
List<String> tokenize(String input) =>
    input.split(RegExp(r'\b|\s')).map((t) => t.trim()).toList()
      ..removeWhere((t) => t.isEmpty);
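For reference, a minimal usage sketch of the API these tests exercise (not part of the commit). It assumes `gclient sync` has fetched the dart/language_model CIPD package into pkg/analysis_server/language_model/lexeme, mirroring the test's `directory` constant; the token list is an arbitrary example:

import 'package:analysis_server/src/services/completion/dart/language_model.dart';

void main() {
  // Hypothetical model location, relative to the SDK checkout root.
  final model = LanguageModel.load('pkg/analysis_server/language_model/lexeme');

  // Scores come back ordered, most probable token first.
  final scores = model.predictWithScores(
      ['for', '(', 'final', 'i', '=', '0', ';', 'i', '<', 'list', '.']);
  for (final entry in scores.entries.take(5)) {
    print('${entry.key}: ${entry.value.toStringAsFixed(3)}');
  }

  // Free the native interpreter when done.
  model.close();
}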
@@ -13,6 +13,7 @@ import 'imported_reference_contributor_test.dart' as imported_ref_test;
 import 'inherited_reference_contributor_test.dart' as inherited_ref_test;
 import 'keyword_contributor_test.dart' as keyword_test;
 import 'label_contributor_test.dart' as label_contributor_test;
+import 'language_model_test.dart' as language_model_test;
 import 'library_member_contributor_test.dart' as library_member_test;
 import 'library_prefix_contributor_test.dart' as library_prefix_test;
 import 'local_constructor_contributor_test.dart' as local_constructor_test;

@@ -37,6 +38,7 @@ main() {
   inherited_ref_test.main();
   keyword_test.main();
   label_contributor_test.main();
+  language_model_test.main();
   library_member_test.main();
   library_prefix_test.main();
   local_constructor_test.main();