Implement pointer mixture network

Here we can see the mixture network assigning probability mass
to a new, project-specific name by reference, as shown here: https://i.imgur.com/6Zbs2qf.png.

I also took this opportunity to decrease the model size to target
100M, in line with our original size goals.

My initial strategy was to implement a separate pointer network
in https://dart-review.googlesource.com/c/sdk/+/117005, but having
a single network that can assign probability mass across both local
references and vocabulary lexemes is better because:

1) there is only one network and model file, and
2) there is no need to coalesce predictions from multiple models.

Change-Id: I23cfc2ece61ce30bb69785149a5a6cf1604af18d
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/121461
Commit-Queue: Ari Aye <ariaye@google.com>
Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
This commit is contained in:
Ari Aye 2019-10-14 18:25:56 +00:00 committed by commit-bot@chromium.org
parent 263bfd9635
commit 70a7ef3f58
6 changed files with 50 additions and 20 deletions

2
DEPS
View file

@ -427,7 +427,7 @@ deps = {
"packages": [
{
"package": "dart/language_model",
"version": "EFtZ0Z5T822s4EUOOaWeiXUppRGKp5d9Z6jomJIeQYcC",
"version": "9fJQZ0TrnAGQKrEtuL3-AXbUfPzYxqpN_OBHr9P4hE4C",
}
],
"dep_type": "cipd",

View file

@ -12,7 +12,7 @@ import 'package:analysis_server/src/services/completion/dart/language_model.dart
import 'package:analyzer/dart/analysis/features.dart';
/// Number of lookback tokens.
const int _LOOKBACK = 100;
const int _LOOKBACK = 50;
/// Minimum probability to prioritize model-only suggestion.
const double _MODEL_RELEVANCE_CUTOFF = 0.5;

View file

@ -174,6 +174,12 @@ List<String> constructQuery(DartCompletionRequest request, int n) {
size < n && token != null && !token.isEof;
token = token.previous) {
if (!token.isSynthetic && token is! ErrorToken) {
// Omit the optional new keyword as we remove it at training time to
// prevent model from suggesting it.
if (token.lexeme == 'new') {
continue;
}
result.add(token.lexeme);
size += 1;
}

View file

@ -13,6 +13,7 @@ import 'package:tflite_native/tflite.dart' as tfl;
/// Interface to TensorFlow-based Dart language model for next-token prediction.
class LanguageModel {
static const _defaultCompletions = 100;
static final _numeric = RegExp(r'^\d+(.\d+)?$');
final tfl.Interpreter _interpreter;
final Map<String, int> _word2idx;
@ -63,48 +64,71 @@ class LanguageModel {
/// Predicts the next token to follow a list of precedent tokens
///
/// Returns a list of tokens, sorted by most probable first.
List<String> predict(Iterable<String> tokens) =>
List<String> predict(List<String> tokens) =>
predictWithScores(tokens).keys.toList();
/// Predicts the next token with confidence scores.
///
/// Returns an ordered map of tokens to scores, sorted by most probable first.
Map<String, double> predictWithScores(Iterable<String> tokens) {
Map<String, double> predictWithScores(List<String> tokens) {
final tensorIn = _interpreter.getInputTensors().single;
tensorIn.data = _transformInput(tokens);
_interpreter.invoke();
final tensorOut = _interpreter.getOutputTensors().single;
return _transformOutput(tensorOut.data);
return _transformOutput(tensorOut.data, tokens);
}
/// Transforms tokens to data bytes that can be used as interpreter input.
List<int> _transformInput(Iterable<String> tokens) {
List<int> _transformInput(List<String> tokens) {
// Replace out of vocabulary tokens.
final sanitizedTokens = tokens
.map((token) => _word2idx.containsKey(token) ? token : '<unknown>');
final sanitizedTokens = tokens.map((token) {
if (_word2idx.containsKey(token)) {
return token;
}
if (_numeric.hasMatch(token)) {
return '<num>';
}
if (_isString(token)) {
return '<str>';
}
return '<unk>';
});
// Get indexes (as floats).
final indexes = Float32List(lookback)
..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));
// Get bytes
return Uint8List.view(indexes.buffer);
}
/// Transforms interpreter output data to map of tokens to scores.
Map<String, double> _transformOutput(List<int> databytes) {
Map<String, double> _transformOutput(
List<int> databytes, List<String> tokens) {
// Get bytes.
final bytes = Uint8List.fromList(databytes);
// Get scores (as floats)
final probabilities = Float32List.view(bytes.buffer);
// Get indexes with scores, sorted by scores (descending)
final entries = probabilities.asMap().entries.toList()
..sort((a, b) => b.value.compareTo(a.value));
final scores = Map<String, double>();
probabilities.asMap().forEach((k, v) {
// x in 0, 1, ..., |V| - 1 correspond to specific members of the vocabulary.
// x in |V|, |V| + 1, ..., |V| + 49 are pointers to reference positions along the
// network input.
if (k >= _idx2word.length + tokens.length) {
return;
}
final lexeme =
k < _idx2word.length ? _idx2word[k] : tokens[k - _idx2word.length];
final sanitized = lexeme.replaceAll('"', '\'');
scores[sanitized] = (scores[sanitized] ?? 0.0) + v;
});
// Get tokens with scores, limiting the length.
return Map.fromEntries(entries.sublist(0, completions))
.map((k, v) => MapEntry(_idx2word[k].replaceAll('"', '\''), v));
final entries = scores.entries.toList()
..sort((a, b) => b.value.compareTo(a.value));
return Map.fromEntries(entries.sublist(0, completions));
}
/// Whether [token] looks like a string literal, i.e. it contains a
/// single- or double-quote character.
bool _isString(String token) => token.contains('"') || token.contains("'");
}

View file

@ -20,7 +20,7 @@ void main() {
final tokens =
tokenize('if (list == null) { return; } for (final i = 0; i < list.');
final response = await ranking.makeRequest('predict', tokens);
expect(response['data']['length'], greaterThan(0.95));
expect(response['data']['length'], greaterThan(0.85));
});
}

View file

@ -11,7 +11,7 @@ import 'package:test/test.dart';
final directory = path.join(File.fromUri(Platform.script).parent.path, '..',
'..', '..', '..', 'language_model', 'lexeme');
const expectedLookback = 100;
const expectedLookback = 50;
void main() {
if (sizeOf<IntPtr>() == 4) {
@ -47,7 +47,7 @@ void main() {
final suggestions = model.predictWithScores(tokens);
final best = suggestions.entries.first;
expect(best.key, 'length');
expect(best.value, greaterThan(0.8));
expect(best.value, greaterThan(0.85));
expect(suggestions, hasLength(model.completions));
});