mirror of
https://github.com/dart-lang/sdk
synced 2024-09-16 00:39:49 +00:00
Implement pointer mixture network
Here we can see the mixture network is assigning probability mass to a new, project-specific name by reference https://i.imgur.com/6Zbs2qf.png. I also took this opportunity to decrease model size targeting 100M, in line with our original size goals. My initial strategy was to implement a separate pointer network in https://dart-review.googlesource.com/c/sdk/+/117005 but having a single network that can assign probability mass across local references and vocabulary lexemes is better since 1) only one network and model file 2) no need to coalesce predictions from multiple models Change-Id: I23cfc2ece61ce30bb69785149a5a6cf1604af18d Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/121461 Commit-Queue: Ari Aye <ariaye@google.com> Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
This commit is contained in:
parent
263bfd9635
commit
70a7ef3f58
2
DEPS
2
DEPS
|
@ -427,7 +427,7 @@ deps = {
|
|||
"packages": [
|
||||
{
|
||||
"package": "dart/language_model",
|
||||
"version": "EFtZ0Z5T822s4EUOOaWeiXUppRGKp5d9Z6jomJIeQYcC",
|
||||
"version": "9fJQZ0TrnAGQKrEtuL3-AXbUfPzYxqpN_OBHr9P4hE4C",
|
||||
}
|
||||
],
|
||||
"dep_type": "cipd",
|
||||
|
|
|
@ -12,7 +12,7 @@ import 'package:analysis_server/src/services/completion/dart/language_model.dart
|
|||
import 'package:analyzer/dart/analysis/features.dart';
|
||||
|
||||
/// Number of lookback tokens.
|
||||
const int _LOOKBACK = 100;
|
||||
const int _LOOKBACK = 50;
|
||||
|
||||
/// Minimum probability to prioritize model-only suggestion.
|
||||
const double _MODEL_RELEVANCE_CUTOFF = 0.5;
|
||||
|
|
|
@ -174,6 +174,12 @@ List<String> constructQuery(DartCompletionRequest request, int n) {
|
|||
size < n && token != null && !token.isEof;
|
||||
token = token.previous) {
|
||||
if (!token.isSynthetic && token is! ErrorToken) {
|
||||
// Omit the optional new keyword as we remove it at training time to
|
||||
// prevent model from suggesting it.
|
||||
if (token.lexeme == 'new') {
|
||||
continue;
|
||||
}
|
||||
|
||||
result.add(token.lexeme);
|
||||
size += 1;
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import 'package:tflite_native/tflite.dart' as tfl;
|
|||
/// Interface to TensorFlow-based Dart language model for next-token prediction.
|
||||
class LanguageModel {
|
||||
static const _defaultCompletions = 100;
|
||||
static final _numeric = RegExp(r'^\d+(.\d+)?$');
|
||||
|
||||
final tfl.Interpreter _interpreter;
|
||||
final Map<String, int> _word2idx;
|
||||
|
@ -63,48 +64,71 @@ class LanguageModel {
|
|||
/// Predicts the next token to follow a list of precedent tokens
|
||||
///
|
||||
/// Returns a list of tokens, sorted by most probable first.
|
||||
List<String> predict(Iterable<String> tokens) =>
|
||||
List<String> predict(List<String> tokens) =>
|
||||
predictWithScores(tokens).keys.toList();
|
||||
|
||||
/// Predicts the next token with confidence scores.
|
||||
///
|
||||
/// Returns an ordered map of tokens to scores, sorted by most probable first.
|
||||
Map<String, double> predictWithScores(Iterable<String> tokens) {
|
||||
Map<String, double> predictWithScores(List<String> tokens) {
|
||||
final tensorIn = _interpreter.getInputTensors().single;
|
||||
tensorIn.data = _transformInput(tokens);
|
||||
_interpreter.invoke();
|
||||
final tensorOut = _interpreter.getOutputTensors().single;
|
||||
return _transformOutput(tensorOut.data);
|
||||
return _transformOutput(tensorOut.data, tokens);
|
||||
}
|
||||
|
||||
/// Transforms tokens to data bytes that can be used as interpreter input.
|
||||
List<int> _transformInput(Iterable<String> tokens) {
|
||||
List<int> _transformInput(List<String> tokens) {
|
||||
// Replace out of vocabulary tokens.
|
||||
final sanitizedTokens = tokens
|
||||
.map((token) => _word2idx.containsKey(token) ? token : '<unknown>');
|
||||
|
||||
final sanitizedTokens = tokens.map((token) {
|
||||
if (_word2idx.containsKey(token)) {
|
||||
return token;
|
||||
}
|
||||
if (_numeric.hasMatch(token)) {
|
||||
return '<num>';
|
||||
}
|
||||
if (_isString(token)) {
|
||||
return '<str>';
|
||||
}
|
||||
return '<unk>';
|
||||
});
|
||||
// Get indexes (as floats).
|
||||
final indexes = Float32List(lookback)
|
||||
..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));
|
||||
|
||||
// Get bytes
|
||||
return Uint8List.view(indexes.buffer);
|
||||
}
|
||||
|
||||
/// Transforms interpreter output data to map of tokens to scores.
|
||||
Map<String, double> _transformOutput(List<int> databytes) {
|
||||
Map<String, double> _transformOutput(
|
||||
List<int> databytes, List<String> tokens) {
|
||||
// Get bytes.
|
||||
final bytes = Uint8List.fromList(databytes);
|
||||
|
||||
// Get scores (as floats)
|
||||
final probabilities = Float32List.view(bytes.buffer);
|
||||
|
||||
// Get indexes with scores, sorted by scores (descending)
|
||||
final entries = probabilities.asMap().entries.toList()
|
||||
..sort((a, b) => b.value.compareTo(a.value));
|
||||
final scores = Map<String, double>();
|
||||
probabilities.asMap().forEach((k, v) {
|
||||
// x in 0, 1, ..., |V| - 1 correspond to specific members of the vocabulary.
|
||||
// x in |V|, |V| + 1, ..., |V| + 49 are pointers to reference positions along the
|
||||
// network input.
|
||||
if (k >= _idx2word.length + tokens.length) {
|
||||
return;
|
||||
}
|
||||
final lexeme =
|
||||
k < _idx2word.length ? _idx2word[k] : tokens[k - _idx2word.length];
|
||||
final sanitized = lexeme.replaceAll('"', '\'');
|
||||
scores[sanitized] = (scores[sanitized] ?? 0.0) + v;
|
||||
});
|
||||
|
||||
// Get tokens with scores, limiting the length.
|
||||
return Map.fromEntries(entries.sublist(0, completions))
|
||||
.map((k, v) => MapEntry(_idx2word[k].replaceAll('"', '\''), v));
|
||||
final entries = scores.entries.toList()
|
||||
..sort((a, b) => b.value.compareTo(a.value));
|
||||
return Map.fromEntries(entries.sublist(0, completions));
|
||||
}
|
||||
|
||||
bool _isString(String token) {
|
||||
return token.indexOf('"') != -1 || token.indexOf("'") != -1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ void main() {
|
|||
final tokens =
|
||||
tokenize('if (list == null) { return; } for (final i = 0; i < list.');
|
||||
final response = await ranking.makeRequest('predict', tokens);
|
||||
expect(response['data']['length'], greaterThan(0.95));
|
||||
expect(response['data']['length'], greaterThan(0.85));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import 'package:test/test.dart';
|
|||
|
||||
final directory = path.join(File.fromUri(Platform.script).parent.path, '..',
|
||||
'..', '..', '..', 'language_model', 'lexeme');
|
||||
const expectedLookback = 100;
|
||||
const expectedLookback = 50;
|
||||
|
||||
void main() {
|
||||
if (sizeOf<IntPtr>() == 4) {
|
||||
|
@ -47,7 +47,7 @@ void main() {
|
|||
final suggestions = model.predictWithScores(tokens);
|
||||
final best = suggestions.entries.first;
|
||||
expect(best.key, 'length');
|
||||
expect(best.value, greaterThan(0.8));
|
||||
expect(best.value, greaterThan(0.85));
|
||||
expect(suggestions, hasLength(model.completions));
|
||||
});
|
||||
|
||||
|
|
Loading…
Reference in a new issue