lucene 9.1.0 HNSW的源码解析(修订版)
修订部分
在指定层带截断查找
lucene版本号修正
前提
向量数据添加过程
代码示例
public static void main(String[] args) throws IOException, InterruptedException {
Directory directory = FSDirectory.open(Paths.get("D:\\my_index_knn"));
StandardAnalyzer analyzer = new StandardAnalyzer();
IndexWriterConfig indexWriterConfig = new IndexWriterConfig(analyzer);
indexWriterConfig.setUseCompoundFile(false);
IndexWriter indexWriter = new IndexWriter(directory, indexWriterConfig);
for (int i = 1; i <= 500; i ++) {
Document document = new Document();
// 添加向量字段
document.add(new KnnVectorField("vector1", TestDataGenerator.generateData(128), VectorSimilarityFunction.EUCLIDEAN));
document.add(new KnnVectorField("vector2", TestDataGenerator.generateData(128), VectorSimilarityFunction.EUCLIDEAN));
indexWriter.addDocument(document);
if (i % 100 == 0) {
indexWriter.flush();
indexWriter.commit();
}
}
indexWriter.flush();
indexWriter.commit();
IndexReader reader = DirectoryReader.open(indexWriter);
IndexSearcher searcher = new IndexSearcher(reader);
// 检索
KnnVectorQuery knnVectorQuery = new KnnVectorQuery("vector1", TestDataGenerator.generateData(128), 10);
TopDocs search = searcher.search(knnVectorQuery, 10);
}
// 字段的信息,向量相关的是:维度和距离度量
private final FieldInfo fieldInfo;
// segment当前使用的内存,会配合阈值使用,是触发flush的一个条件
private final Counter iwBytesUsed;
// 存储向量数据
private final List<float[]> vectors = new ArrayList<>();
// 位图,不一定每个文档都有这个字段,用位图来记录包含这个字段的文档id
private final DocsWithFieldSet docsWithField;
public void addValue(int docID, float[] vectorValue) {
// 向量字段,每个文档只能有一个
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
。。。(省略)
// 记录docId
docsWithField.add(docID);
// 存储向量数据
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
// 更新当前段的内存数据
updateBytesUsed();
lastDocID = docID;
}
HNSW查找和构建的工具类
NeighborArray
// 邻居个数
private int size;
// 所有邻居到节点的距离
float[] score;
// 所有邻居节点的编号
int[] node;
BoundsChecker
public abstract class BoundsChecker {
// 边界值
float bound;
// 如果边界更优则更新边界值。
// 如果是最小值检查,sample比当前最小值小,则更新边界。
// 如果是最大值检查,sample比当前最大值大,则更新边界。
public abstract void update(float sample);
// 设置边界
public void set(float sample) {
bound = sample;
}
// 边界检查
// 如果是最小值检查,则判断sample是否大于当前的最小值边界
// 如果是最大值检查,则判断sample是否小于当前的最大值边界
public abstract boolean check(float sample);
// 根据是否逆序创建Min或者Max,Min和Max是BoundsChecker内部类
public static BoundsChecker create(boolean reversed) {
if (reversed) {
return new Min();
} else {
return new Max();
}
}
// 最大值检查工具类
public static class Max extends BoundsChecker {
Max() {
bound = Float.NEGATIVE_INFINITY;
}
@Override
public void update(float sample) {
if (sample > bound) {
bound = sample;
}
}
@Override
public boolean check(float sample) {
return sample < bound;
}
}
// 最小值检查工具类
public static class Min extends BoundsChecker {
Min() {
bound = Float.POSITIVE_INFINITY;
}
@Override
public void update(float sample) {
if (sample < bound) {
bound = sample;
}
}
@Override
public boolean check(float sample) {
return sample > bound;
}
}
}
LongHeap
// 带溢出检查的方式插入
public boolean insertWithOverflow(long value) {
// 如果已经达到堆的大小限制
if (size >= maxSize) {
// 因为是最小堆,所以比堆顶小就直接返回插入失败
if (value < heap[1]) {
return false;
}
// 走到这里说明比堆顶大,直接替换堆顶元素,然后调整堆重新成为最小堆
updateTop(value);
return true;
}
// 没有达到堆大小限制直接插入
push(value);
return true;
}
NeighborQueue
// 因为LongHeap是最小堆,所以如果是需要最大堆的功能,则需要做倒序转化
private static enum Order {
// 自然顺序:从小到大
NATURAL {
@Override
long apply(long v) {
return v;
}
},
// 倒序:从大到小
REVERSED {
@Override
long apply(long v) {
return -1 - v;
}
};
// 值转化
// 自然顺序,不需要转化,也就是最小堆
// 逆序,则存-1 - v,相当于是最大堆
abstract long apply(long v);
}
// 存储距离和节点编号的复合体,可以简单理解成是节点和目标节点的距离
private final LongHeap heap;
// 自然顺序还是逆序
private final Order order;
// 遍历过的节点
private int visitedCount;
// 是否是提前截断导致查找停止
private boolean incomplete;
// initialSize:堆的大小限制
// reversed:是否逆序,控制是最大堆还是最小堆
public NeighborQueue(int initialSize, boolean reversed) {
this.heap = new LongHeap(initialSize);
this.order = reversed ? Order.REVERSED : Order.NATURAL;
}
// 添加新的邻居节点及其距离,经过编码之后插入堆中
public void add(int newNode, float newScore) {
heap.push(encode(newNode, newScore));
}
// 带溢出检查的方式添加新的邻居节点及其距离,编码之后使用堆的insertWithOverflow方法插入
public boolean insertWithOverflow(int newNode, float newScore) {
return heap.insertWithOverflow(encode(newNode, newScore));
}
// 高32位存储的是距离,低32位存的是节点编号
private long encode(int node, float score) {
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
}
// 删除堆顶元素,返回的是节点编号,直接long转int就行,低32位就是编号
public int pop() {
return (int) order.apply(heap.pop());
}
// 返回堆顶节点的编号
public int topNode() {
return (int) order.apply(heap.top());
}
// 返回堆顶节点的距离
public float topScore() {
return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
}
HnswGraph
public abstract class HnswGraph {
// 定位到target节点的邻居的存储位置,然后可以调用nextNeighbor遍历所有的邻居
public abstract void seek(int level, int target) throws IOException;
// HNSW中的节点总数,其实也是对底层的节点总数,因为最底层包含了所有的节点
public abstract int size();
// 获取邻居节点,如果遍历结束或者没有邻居返回NO_MORE_DOCS
public abstract int nextNeighbor() throws IOException;
// 返回HNSW的层数
public abstract int numLevels() throws IOException;
// 搜索时,位于顶层的起始遍历节点,只有一个起始节点entry point
public abstract int entryNode() throws IOException;
// 指定层的节点迭代器,可以通过迭代器获取某一层的所有节点
public abstract NodesIterator getNodesOnLevel(int level) throws IOException;
}
// 每个节点最大的邻居个数
private final int maxConn;
// 总层数
private int numLevels;
// 检索的初始节点,在最顶层。在论文中表示为entry point
private int entryNode;
// 每一层的节点
// 默认第0层是所有的节点,因此不需要存储数据,nodesByLevel.get(0) == null
private final List<int[]> nodesByLevel;
// 每一层节点的邻居
private final List<List<NeighborArray>> graph;
// 迭代遍历邻居用的
private int upto;
private NeighborArray cur;
// levelOfFirstNode是随机生成的,代表的是初始的层编号,注意层编号是从0开始的。
OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
this.maxConn = maxConn;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
// 默认是第一个向量作为顶层的节点,也是HNSW构建的起始点
this.entryNode = 0;
// 为每一层初始化邻居列表
for (int i = 0; i < numLevels; i++) {
graph.add(new ArrayList<>());
// 根据经验值初始化列表大小
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
}
this.nodesByLevel = new ArrayList<>(numLevels);
// 第0层直接设置为null,代表所有的节点
nodesByLevel.add(null);
// 每一层都加上初始节点,因为最顶层的节点肯定在所有层中都存在
for (int l = 1; l < numLevels; l++) {
nodesByLevel.add(new int[] {0});
}
}
// level:要加入哪一层,是新加入节点通过随机函数产生的
// node:加入的节点编码
public void addNode(int level, int node) {
// 因为第0层默认是所有节点都存在的,所以只记录0层以上的
if (level > 0) {
// 如果要加入的层大于目前的最高层
if (level >= numLevels) {
// 超出目前最高层的每一层都加上该节点
for (int i = numLevels; i <= level; i++) {
graph.add(new ArrayList<>());
nodesByLevel.add(new int[] {node});
}
// 更新层数,层编号是从0开始,所以加1
numLevels = level + 1;
// 更新初始节点,这是唯一的更新入口,由此可见 entryNode 是最高层的第一个节点
entryNode = node;
} else {
// 当前层中加入节点,如果超出数组大小,有扩容处理
int[] nodes = nodesByLevel.get(level);
int idx = graph.get(level).size();
if (idx < nodes.length) {
nodes[idx] = node;
} else {
nodes = ArrayUtil.grow(nodes);
nodes[idx] = node;
nodesByLevel.set(level, nodes);
}
}
}
// 新加入的节点无论在那一层,都为这一层初始化一个邻居容器
graph.get(level).add(new NeighborArray(maxConn + 1));
}
HNSW查找
// 距离度量方式
private final VectorSimilarityFunction similarityFunction;
// 最近邻的候选者。以欧式距离为例,则是一个最小堆,队顶元素是距离目标节点最近的节点
private final NeighborQueue candidates;
// 用来标记访问过的节点位图:因为邻居的邻居可能也是自己的邻居,相同节点不需要重复访问
private final BitSet visited;
在指定层带截断的查找
private NeighborQueue searchLevel(
float[] query, // 要查找的目标向量
int topK,
int level, // 在那一层查找
final int[] eps, // 起始遍历节点的编号列表
RandomAccessVectorValues vectors, // 待检索的向量集合
HnswGraph graph,
Bits acceptOrds,// 相当于是向量编号白名单,用来过滤查找结果。lucene中的段是不可变的,如果段中的数据被删除,
// 真正的数据不会被删除,而是用位图记录起来,acceptOrds相当于是记录存活的有效向量id
int visitedLimit) // 正常是访问迭代到最近邻,这个参数可以控制提前停止迭代,当然找到的一般不是最优结果
throws IOException {
int size = graph.size();
// 最终的topK结果,注意是带大小限制的堆,以欧式距离为例,是个大顶堆
NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
// 初始化,清空 HnswGraphSearcher 的两个成员变量,候选者(candidates) 和已经访问过的节点位图(visited)
clearScratchState();
// 记录当前访问过的节点个数,用来和visitedLimit判断是否截断,停止搜索
int numVisited = 0;
// 遍历起始节点,获取初始的候选节点和结果集
for (int ep : eps) {
// 如果没有访问过
if (visited.getAndSet(ep) == false) {
// 是否要截断
if (numVisited >= visitedLimit) {
// 截断标记
results.markIncomplete();
break;
}
// 计算entry point和目标节点的距离
float score = similarityFunction.compare(query, vectors.vectorValue(ep));
numVisited++;
// 加入候选者堆中
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
// 如果向量在白名单中,则加入最终结果的堆中
results.add(ep, score);
}
}
}
// 以下的流程就是判断候选结果中的最近节点的距离是不是都比结果集中的大,如果是,说明已经找到了最近邻topK,则返回结果集,否则继续查找候选集中的节点的邻居,不断迭代。
// 以欧式距离为例,设置一个最小值边界检查,如果待检查的值大于这个边界,返回true。
// 把当前结果集中的最大距离设置为边界
BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
if (results.size() >= topK) {
bound.set(results.topScore());
}
// 遍历候选节点堆(最小堆),也就是遍历是从近到远,直到没有候选者或者截断发生
while (candidates.size() > 0 && results.incomplete() == false) {
// 获取堆顶节点,是候选列表中的最近节点
float topCandidateScore = candidates.topScore();
// 如果候选列表中的距离最近节点都没有满足边界要求,则结束迭代
if (bound.check(topCandidateScore)) {
break;
}
int topCandidateNode = candidates.pop();
// 定位到候选节点的邻居
graph.seek(level, topCandidateNode);
int friendOrd;
// 遍历候选节点的邻居
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
// 访问过则忽略
if (visited.getAndSet(friendOrd)) {
continue;
}
// 截断发生
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
numVisited++;
// 如果节点比结果集中的堆顶节点要近,这个条件可能造成最后结果没有topK,我觉得是个bug,应该是bound.check(score) == false || results.size() < topK
if (bound.check(score) == false) {
// 当前遍历节点加入候选列表
candidates.add(friendOrd, score);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
if (results.insertWithOverflow(friendOrd, score) && results.size() >= topK) {
// 如果超出结果集的大小限制,说明堆顶节点被替换,则重新设置堆顶节点的距离作为边界
bound.set(results.topScore());
}
}
}
}
}
// 前面从eps初始化的时候,可能超过topK
while (results.size() > topK) {
results.pop();
}
results.setVisitedCount(numVisited);
return results;
}
在指定层不带截断的查找
NeighborQueue searchLevel(
float[] query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues vectors,
HnswGraph graph)
throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
}
HNSW全局检索使用的接口
public static NeighborQueue search(
float[] query, // 待查找的目标向量
int topK,
RandomAccessVectorValues vectors, // 目标向量集合
VectorSimilarityFunction similarityFunction,
HnswGraph graph, // HNSW图,主要是层次信息和节点的邻居信息
Bits acceptOrds, // 相当于是向量编号白名单,用来过滤查找结果。lucene中的段是不可变的,如果段中的数据被删除,
// 真正的数据不会被删除,而是用位图记录起来,acceptOrds相当于是记录存活的有效向量id
int visitedLimit) // 正常是访问迭代到最近邻,这个参数可以控制提前停止迭代,当然找到的一般不是最优结果
throws IOException {
// 因为 graphSearcher 不是线程安全的,所以使用局部变量
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(topK, similarityFunction.reversed == false),
new SparseFixedBitSet(vectors.size()));
// 存储最后的结果
NeighborQueue results;
// 起始遍历的节点为最顶层的entry point
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
// 从最顶层开始遍历直到倒数第二层,每一层中找到一个最近邻节点作为下一层的起始点
// 这是一个快速靠近目标节点的过程
for (int level = graph.numLevels() - 1; level >= 1; level--) {
// 调用了带截断的在指定层查找方法
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
eps[0] = results.pop();
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
}
// 最后一层包含了所有的节点,从最后一层中找到真正的最近邻topK
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
HNSW构建
触发构建时机
构建的入口
private void writeVectors(SegmentWriteState state, Sorter.DocMap sortMap) throws IOException {
// KnnVectorsWriter是个抽象类,在lucene9.1.0中的实现是Lucene91HnswVectorsWriter。
KnnVectorsWriter knnVectorsWriter = null;
boolean success = false;
try {
for (int i = 0; i < fieldHash.length; i++) {
PerField perField = fieldHash[i];
while (perField != null) {
。。。(一些判断和knnVectorsWriter的初始化)
// 真正执行flush的逻辑
perField.vectorValuesWriter.flush(sortMap, knnVectorsWriter);
perField.vectorValuesWriter = null;
。。。(一些判断)
}
}
if (knnVectorsWriter != null) {
// 为相关索引文件添加注脚,主要是校验码
knnVectorsWriter.finish();
}
success = true;
} finally {
if (success) {
IOUtils.close(knnVectorsWriter);
} else {
IOUtils.closeWhileHandlingException(knnVectorsWriter);
}
}
}
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
// KnnVectorsReader 中最主要的方法是getVectorValues获取所有待构建的向量数据
// getVectorValues 返回的也是封装了向量数据,维度等信息的BufferedVectorValues
KnnVectorsReader knnVectorsReader =
new KnnVectorsReader() {
@Override
public long ramBytesUsed() {
return 0;
}
@Override
public void close() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
VectorValues vectorValues =
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
}
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
throw new UnsupportedOperationException();
}
};
// 核心逻辑
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
}
// 持久化的相关文件的输出流
// meta:向量元信息文件输出流
// vectorData: 向量数据文件输出流
// vectorIndex: 向量索引文件输出流(存储邻居信息)
private final IndexOutput meta, vectorData, vectorIndex;
// segment中的文档总个数
private final int maxDoc;
// 每个节点邻居的上限
private final int maxConn;
// 从第0层到Math.min(nodeLevel, curMaxLevel),在每一层中查询最近邻的候选个数
private final int beamWidth;
// 是否成功构建完成
private boolean finished;
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
throws IOException {
// 字节对齐(https://www.thinbug.com/q/47510783)
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
// 前面提到的BufferedVectorValues,保存了向量信息
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
// 临时文件,用来写所有的向量数据,如果hnsw构建失败,则数据还在。
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
IndexInput vectorDataInput = null;
boolean success = false;
try {
// 向量数据写入临时文件
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
// 将临时文件中的数据拷贝到真正的segment中的向量数据文件
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
// 又把向量数据重新封装到OffHeapVectorValues中
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
new Lucene91HnswVectorsReader.OffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput);
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
// 写入HNSW图结构,其中包含了构建流程
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
// 写向量元信息
writeMeta(
fieldInfo,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
docsWithField,
graph);
success = true;
} finally {
IOUtils.close(vectorDataInput);
if (success) {
segmentWriteState.directory.deleteFile(tempVectorData.getName());
} else {
IOUtils.closeWhileHandlingException(tempVectorData);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempVectorData.getName());
}
}
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// HNSW的构建器
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
// 构建HNSW
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// 构建好的HNSW写入文件
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
// Destructively modify; it's ok we are discarding it after this
int[] nnodes = neighbors.node();
Arrays.sort(nnodes, 0, size);
for (int i = 0; i < size; i++) {
int nnode = nnodes[i];
assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
vectorIndex.writeInt(nnode);
}
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
// offsets
for (int i = size; i < maxConn; i++) {
vectorIndex.writeInt(0);
}
}
}
return graph;
}
构建流程
// 每个节点的最大邻居个数
private final int maxConn;
// 构建查找节点邻居时,最多的候选邻居个数
// 然后再从这些候选邻居中根据启发式选择算法(HNSW论文中)选择出maxConn个真正的邻居
private final int beamWidth;
// 新增节点时需要一个随机函数为其生成该节点可以到达的最高层,ml是这个函数的标准化参数
private final double ml;
// 候选的邻居节点会按照距离由远到近排序存储在scratch,然后执行启发式选择算法从中选择真正的邻居
private final NeighborArray scratch;
// 距离度量
private final VectorSimilarityFunction similarityFunction;
// 在查找邻居的时候使用vectorValues获取指定的向量
private final RandomAccessVectorValues vectorValues;
// 在执行启发式邻居选择算法时通过buildVectors获取指定的向量。
// buildVectors底层的向量数据跟vectorValues是一样的。
// 虽然RandomAccessVectorValues不是线程安全的,但是当前的构建是单线,
// 所以理论上应该可以和vectorValues公用一个,有懂的伙伴欢迎讨论private RandomAccessVectorValues buildVectors;
// 随机生成层高的函数会用到
private final SplittableRandom random;
// 启发式选择算法用到
private final BoundsChecker bound;
// 构建的过程需要为新加入的节点查找近邻候选者
private final HnswGraphSearcher graphSearcher;
// 当前正在构建的HNSW图
final OnHeapHnswGraph hnsw;
public HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
long seed) {
// 可见vectorValues和buildVectors底层是同一份向量数据
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
this.similarityFunction = Objects.requireNonNull(similarityFunction);
this.maxConn = maxConn;
this.beamWidth = beamWidth;
// 标准化参数和maxConn有关
this.ml = 1 / Math.log(1.0 * maxConn);
this.random = new SplittableRandom(seed);
// 前面介绍OnHeapHnswGraph的时候说过,levelOfFirstNode是随机生成的初始层号。
int levelOfFirstNode = getRandomGraphLevel(ml, random);
// 初始化HNSW结构,等待加入节点
this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
}
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
// vectors 和 vectorValues 底层也是同一份向量数据
// 之所以要区分,我理解是为了顺序IO。vectors是用来遍历所有的向量数据进行构建的,从头往后遍历,
// 而vectorValues是随机访问。如果有其他解释,欢迎讨论。
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}
// 第0个节点默认就是第一个向量,在初始化HNSW的时候已经默认构建好了第一个向量,还记得那个levelOfFirstNode吗?
// 开始遍历节点,添加进HNSW(添加的意思是加入节点信息,也加入节点的邻居信息)
for (int node = 1; node < vectors.size(); node++) {
// 添加当前遍历的节点
addGraphNode(node, vectors.vectorValue(node));
}
return hnsw;
}
void addGraphNode(int node, float[] value) throws IOException {
NeighborQueue candidates;
// 获取当前处理节点的最高层号,随机产生,
final int nodeLevel = getRandomGraphLevel(ml, random);
// 当前HNSW的最高层号,层编号是从0开始的
int curMaxLevel = hnsw.numLevels() - 1;
// 起始遍历的节点,初始化为最顶层的entry point
int[] eps = new int[] {hnsw.entryNode()};
// 如果待加入节点的nodeLevel大于当前最高层,则在超出最高层的每一层中加入这个节点
for (int level = nodeLevel; level > curMaxLevel; level--) {
hnsw.addNode(level, node);
}
// 如果待加入节点的nodeLevel小于当前最高层,则从最高层开始在超出nodeLevel的每一层找一个最近点,然后作为下一层
// 的entry point,继续往下迭代
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
eps = new int[] {candidates.pop()};
}
// 在[0, min(nodeLevel, curMaxLevel)]中的每一层寻找最近的beamWidth个最近邻,
// 然后从中由启发式选择算法挑选maxCnn个作为真正的邻居
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
// 在当前层中寻找beamWidth个候选邻居节点
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
// 当前层找到的候选邻居节点作为下一层的entry point
eps = candidates.nodes();
// 把节点加入当前层,并生成一个空的邻居列表,在下面addDiverseNeighbors中会填充这个空的邻居列表
hnsw.addNode(level, node);
// 用启发式选择算法从候选的邻居中选择maxConn个邻居
addDiverseNeighbors(level, node, candidates);
}
}
启发式选择算法
lucene官方测试报告显示了使用启发式选择算法替代最近邻邻居选择算法,召回和时延都有明显的提升,具体的指标可以查看:https://issues.apache.org/jira/browse/LUCENE-9644
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
throws IOException {
// 获取当前节点空的邻居列表
NeighborArray neighbors = hnsw.getNeighbors(level, node);
// 对当前的候选者按照距离由远到近排序,暂存进scratch,底层是个数组,这个方法简单就不展开了
popToScratch(candidates);
// 按照启发式算法选择最多样(视觉效果是最发散)的邻居(多样的度量后面介绍)存入neighbors中
selectDiverse(neighbors, scratch);
// 邻居是互为邻居,所以也需要为邻居节点把当前节点作为邻居
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
// 把当前节点作为邻居节点的邻居
nbrNbr.add(node, neighbors.score[i]);
// 因为当前节点的邻居可能因为增加当前节点这个邻居导致邻居总数超过maxConn,
// 所以需要调整已有的节点的邻居,满足最多邻居个数不超过maxConn
if (nbrNbr.size() > maxConn) {
// 如果超出了最大邻居个数,则按照多样性去掉一个
diversityUpdate(nbrNbr);
}
}
}
private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
// 从距离最近的候选节点开始遍历
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
int cNode = candidates.node[i];
float cScore = candidates.score[i];
assert cNode < hnsw.size();
// 如果满足多样性的检查,就加入结果集
if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
neighbors.add(cNode, cScore);
}
}
}
private boolean diversityCheck(
float[] candidate,
float score, // 候选节点和当前加入节点的距离
NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
throws IOException {
// 欧式距离为例
// 如果待校验的值比score(候选节点和当前加入节点的距离)小,则bound.check为false
bound.set(score);
// 遍历当前已经选出的邻居
for (int i = 0; i < neighbors.size(); i++) {
// 当前候选邻居和其他邻居之间的距离
float diversityCheck =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
// 当前候选邻居和其他邻居之间的距离有一个比当前候选邻居和节点之间的距离小,则该候选邻居多样性不足,不能成为邻居。
// 也就是邻居之间的距离要足够远,可以想象到要找的邻居是遍布四周的,而不是集中在一堆,
// 这样查找的时候可以快速从该节点通过邻居出发遍历不同方向的节点
if (bound.check(diversityCheck) == false) {
return false;
}
}
return true;
}
private int findNonDiverse(NeighborArray neighbors) throws IOException {
// 遍历寻找第一个多样性违规的邻居
for (int i = neighbors.size() - 1; i >= 0; i--) {
int nbrNode = neighbors.node[i];
bound.set(neighbors.score[i]);
float[] nbrVector = vectorValues.vectorValue(nbrNode);
// 不用担心越界问题,能走到这里肯定是maxConn + 1个邻居
for (int j = maxConn; j > i; j--) {
float diversityCheck =
similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
// 如果节点i和节点j离的太近了,则返回该ndoe编号,用来删除
if (bound.check(diversityCheck) == false) {
return i;
}
}
}
// 不存在多样性违规的节点,则返回-1
return -1;
}
private void diversityUpdate(NeighborArray neighbors) throws IOException {
assert neighbors.size() == maxConn + 1;
// 寻找多样性违规的邻居
int replacePoint = findNonDiverse(neighbors);
// 如果没有多样性违规的节点
if (replacePoint == -1) {
// 如果新加入的邻居距离比第1个邻居远,则直接删除新加入的邻居,感觉有点粗暴
bound.set(neighbors.score[0]);
if (bound.check(neighbors.score[maxConn])) {
neighbors.removeLast();
return;
} else {
replacePoint = 0;
}
}
// 用新加入的邻居替换可以替换的邻居
neighbors.node[replacePoint] = neighbors.node[maxConn];
neighbors.score[replacePoint] = neighbors.score[maxConn];
neighbors.removeLast();
}
尾声
与论文算法的区别
可以讨论的问题
•数据删除导致的搜索问题
写在最后
感谢看到此处的看官,如有疏漏,欢迎指正讨论。