Commit: Fix memory zones

honnibal committed Sep 9, 2024
1 parent 59ac7e6 commit a019315

Showing 9 changed files with 76 additions and 71 deletions.
2 changes: 1 addition & 1 deletion spacy/lang/kmr/__init__.py
@@ -1,5 +1,5 @@
-from .lex_attrs import LEX_ATTRS
 from ...language import BaseDefaults, Language
+from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS


1 change: 0 additions & 1 deletion spacy/lang/kmr/lex_attrs.py
@@ -1,6 +1,5 @@
 from ...attrs import LIKE_NUM
 
-
 _num_words = [
     "sifir",
     "yek",
35 changes: 34 additions & 1 deletion spacy/language.py
@@ -5,7 +5,7 @@
 import random
 import traceback
 import warnings
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from itertools import chain, cycle
@@ -31,6 +31,7 @@
 )
 
 import srsly
+from cymem.cymem import Pool
 from thinc.api import Config, CupyOps, Optimizer, get_current_ops
 
 from . import about, ty, util
@@ -2091,6 +2092,38 @@ def replace_listeners(
                 util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
                 tok2vec.remove_listener(listener, pipe_name)
 
+    @contextmanager
+    def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
+        """Begin a block where all resources allocated during the block will
+        be freed at the end of it. If a resource was created within the
+        memory zone block, accessing it outside the block is invalid.
+        Behaviour of this invalid access is undefined. Memory zones should
+        not be nested.
+
+        The memory zone is helpful for services that need to process large
+        volumes of text with a defined memory budget.
+
+        Example
+        -------
+        >>> with nlp.memory_zone():
+        ...     for doc in nlp.pipe(texts):
+        ...         process_my_doc(doc)
+        >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+        """
+        if mem is None:
+            mem = Pool()
+        # The ExitStack allows programmatic nested context managers.
+        # We don't know how many we need, so it would be awkward to have
+        # them as nested blocks.
+        with ExitStack() as stack:
+            contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
+            if hasattr(self.tokenizer, "memory_zone"):
+                contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
+            for _, pipe in self.pipeline:
+                if hasattr(pipe, "memory_zone"):
+                    contexts.append(stack.enter_context(pipe.memory_zone(mem)))
+            yield mem
+
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
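For readers of the diff: a minimal usage sketch of the memory_zone context manager added above, assuming a loaded pipeline. The batch-processing helper and entity extraction are illustrative stand-ins, not part of the commit.

import spacy

nlp = spacy.load("en_core_web_sm")  # assumption: any loaded pipeline works

def extract_entities(texts):
    """Process one batch of texts inside a single memory zone."""
    results = []
    with nlp.memory_zone():
        for doc in nlp.pipe(texts):
            # Copy what we need out as plain Python objects: the Doc and any
            # transient Vocab/StringStore entries are freed when the zone exits.
            results.append([(ent.text, ent.label_) for ent in doc.ents])
    return results  # safe: holds no zone-owned memory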
2 changes: 1 addition & 1 deletion spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -203,7 +203,7 @@ cdef class ArcEagerGold:
     def __init__(self, ArcEager moves, StateClass stcls, Example example):
         self.mem = Pool()
         heads, labels = example.get_aligned_parse(projectivize=True)
-        labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels]
+        labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
         sent_starts = _get_aligned_sent_starts(example)
         assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
         self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
2 changes: 1 addition & 1 deletion spacy/pipeline/_parser_internals/nonproj.pyx
@@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
             new_label, head_label = label.split(DELIMITER)
             new_head = _find_new_head(doc[i], head_label)
             doc.c[i].head = new_head.i - i
-            doc.c[i].dep = doc.vocab.strings.add(new_label)
+            doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
     set_children_from_heads(doc.c, 0, doc.length)
     return doc

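A note on the two parser changes above: passing allow_transient=False pins dependency-label strings permanently, since the hashes written into the parse state must stay resolvable after any active memory zone exits. A small sketch of that guarantee, assuming this commit's StringStore API:

from spacy.strings import StringStore

ss = StringStore()
dep = ss.add("nsubj", allow_transient=False)  # pinned: survives any zone
with ss.memory_zone():
    pass  # transient allocations made here would be freed on exit
assert ss[dep] == "nsubj"  # the pinned label is still resolvable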
1 change: 0 additions & 1 deletion spacy/strings.pxd
@@ -28,5 +28,4 @@ cdef class StringStore:
     cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
     cdef vector[hash_t] _transient_keys
-    cdef PreshMap _transient_map
     cdef Pool _non_temp_mem
83 changes: 28 additions & 55 deletions spacy/strings.pyx
@@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
 from libc.stdint cimport uint32_t
 from libc.string cimport memcpy
 from murmurhash.mrmr cimport hash32, hash64
+from preshed.maps cimport map_clear
 
 import srsly

@@ -125,10 +126,9 @@ cdef class StringStore:
         self.mem = Pool()
         self._non_temp_mem = self.mem
         self._map = PreshMap()
-        self._transient_map = None
         if strings is not None:
             for string in strings:
-                self.add(string)
+                self.add(string, allow_transient=False)

     def __getitem__(self, object string_or_id):
         """Retrieve a string from a given hash, or vice versa.
@@ -158,17 +158,17 @@
                 return SYMBOLS_BY_INT[str_hash]
             else:
                 utf8str = <Utf8Str*>self._map.get(str_hash)
-                if utf8str is NULL and self._transient_map is not None:
-                    utf8str = <Utf8Str*>self._transient_map.get(str_hash)
-                if utf8str is NULL:
-                    raise KeyError(Errors.E018.format(hash_value=string_or_id))
-                else:
-                    return decode_Utf8Str(utf8str)
         else:
             # TODO: Raise an error instead
             utf8str = <Utf8Str*>self._map.get(string_or_id)
-            if utf8str is NULL and self._transient_map is not None:
-                utf8str = <Utf8Str*>self._transient_map.get(str_hash)
-            if utf8str is NULL:
-                raise KeyError(Errors.E018.format(hash_value=string_or_id))
-            else:
-                return decode_Utf8Str(utf8str)
+        if utf8str is NULL:
+            raise KeyError(Errors.E018.format(hash_value=string_or_id))
+        else:
+            return decode_Utf8Str(utf8str)

     def as_int(self, key):
         """If key is an int, return it; otherwise, get the int value."""
@@ -184,16 +184,12 @@
         else:
             return self[key]
 
-    def __reduce__(self):
-        strings = list(self.non_transient_keys())
-        return (StringStore, (strings,), None, None, None)
-
     def __len__(self) -> int:
         """The number of strings in the store.
 
         RETURNS (int): The number of strings in the store.
         """
-        return self._keys.size() + self._transient_keys.size()
+        return self.keys.size() + self._transient_keys.size()
 
     @contextmanager
     def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@@ -209,13 +205,13 @@
         if mem is None:
             mem = Pool()
         self.mem = mem
-        self._transient_map = PreshMap()
         yield mem
-        self.mem = self._non_temp_mem
-        self._transient_map = None
+        for key in self._transient_keys:
+            map_clear(self._map.c_map, key)
         self._transient_keys.clear()
+        self.mem = self._non_temp_mem
 
-    def add(self, string: str, allow_transient: bool = False) -> int:
+    def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
         """Add a string to the StringStore.
 
         string (str): The string to add.
@@ -226,6 +222,8 @@
             internally should not.
         RETURNS (uint64): The string's hash value.
         """
+        if allow_transient is None:
+            allow_transient = self.mem is not self._non_temp_mem
         cdef hash_t str_hash
         if isinstance(string, str):
             if string in SYMBOLS_BY_STR:
@@ -273,17 +271,13 @@
             # TODO: Raise an error instead
             if self._map.get(string_or_id) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
         if str_hash < len(SYMBOLS_BY_INT):
             return True
         else:
             if self._map.get(str_hash) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
@@ -292,32 +286,21 @@
 
         YIELDS (str): A string in the store.
         """
-        yield from self.non_transient_keys()
-        yield from self.transient_keys()
-
-    def non_transient_keys(self) -> Iterator[str]:
-        """Iterate over the stored strings in insertion order.
-
-        RETURNS: A list of strings.
-        """
         cdef int i
         cdef hash_t key
         for i in range(self.keys.size()):
             key = self.keys[i]
             utf8str = <Utf8Str*>self._map.get(key)
             yield decode_Utf8Str(utf8str)
+        for i in range(self._transient_keys.size()):
+            key = self._transient_keys[i]
+            utf8str = <Utf8Str*>self._map.get(key)
+            yield decode_Utf8Str(utf8str)
 
+    def __reduce__(self):
+        strings = list(self)
+        return (StringStore, (strings,), None, None, None)
+
-    def transient_keys(self) -> Iterator[str]:
-        if self._transient_map is None:
-            return []
-        for i in range(self._transient_keys.size()):
-            utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
-            yield decode_Utf8Str(utf8str)
-
     def values(self) -> List[int]:
         """Iterate over the stored strings hashes in insertion order.
 
@@ -327,12 +310,9 @@
         hashes = [None] * self._keys.size()
         for i in range(self._keys.size()):
             hashes[i] = self._keys[i]
-        if self._transient_map is not None:
-            transient_hashes = [None] * self._transient_keys.size()
-            for i in range(self._transient_keys.size()):
-                transient_hashes[i] = self._transient_keys[i]
-        else:
-            transient_hashes = []
+        transient_hashes = [None] * self._transient_keys.size()
+        for i in range(self._transient_keys.size()):
+            transient_hashes[i] = self._transient_keys[i]
         return hashes + transient_hashes
 
     def to_disk(self, path):
@@ -383,8 +363,10 @@
     def _reset_and_load(self, strings):
         self.mem = Pool()
         self._non_temp_mem = self.mem
         self._map = PreshMap()
+        self.keys.clear()
+        self._transient_keys.clear()
         for string in strings:
             self.add(string, allow_transient=False)
@@ -401,19 +383,10 @@
         cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
         if value is not NULL:
             return value
-        if allow_transient and self._transient_map is not None:
-            # If we've already allocated a transient string, and now we
-            # want to intern it permanently, we'll end up with the string
-            # in both places. That seems fine -- I don't see why we need
-            # to remove it from the transient map.
-            value = <Utf8Str*>self._transient_map.get(key)
-            if value is not NULL:
-                return value
         value = _allocate(self.mem, <unsigned char*>utf8_string, length)
-        if allow_transient and self._transient_map is not None:
-            self._transient_map.set(key, value)
+        self._map.set(key, value)
+        if allow_transient and self.mem is not self._non_temp_mem:
             self._transient_keys.push_back(key)
         else:
-            self._map.set(key, value)
             self.keys.push_back(key)
         return value
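To illustrate the new add() default, a sketch (assuming this commit) of how transient strings behave around a zone: outside a zone, add() interns permanently; inside, strings default to transient and are evicted from the map when the zone exits.

from spacy.strings import StringStore

ss = StringStore()
ss.add("permanent")  # outside a zone: allow_transient resolves to False
with ss.memory_zone():
    ss.add("scratch")  # inside a zone: allow_transient resolves to True
    assert "scratch" in ss  # visible while the zone is active
assert "scratch" not in ss  # evicted via map_clear on zone exit
assert "permanent" in ss  # permanent strings are untouched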
8 changes: 2 additions & 6 deletions spacy/tokenizer.pyx
@@ -517,12 +517,8 @@ cdef class Tokenizer:
         if n <= 0:
             # avoid mem alloc of zero length
             return 0
-        # Historically this check was mostly used to avoid caching
-        # chunks that had tokens owned by the Doc. Now that that's
-        # not a thing, I don't think we need this?
-        for i in range(n):
-            if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
-                return 0
+        if self.vocab.in_memory_zone:
+            return 0
         # See #1250
         if has_special[0]:
             return 0
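The tokenizer change replaces a per-token ownership scan with a single in_memory_zone check: cache entries hold pointers to lexemes, and a lexeme created inside a zone dies with it, so chunks tokenized during a zone must not be cached. A sketch of the observable behaviour, assuming this commit; tokenization output is unchanged either way.

import spacy

nlp = spacy.blank("en")  # assumption: a blank pipeline suffices to show this
with nlp.memory_zone():
    inside = [t.text for t in nlp.tokenizer("one brandnewword here")]
outside = [t.text for t in nlp.tokenizer("one brandnewword here")]
assert inside == outside  # identical tokens; only the cache write is skipped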
13 changes: 9 additions & 4 deletions spacy/vocab.pyx
@@ -5,6 +5,7 @@ from typing import Iterator, Optional
 import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
+from preshed.maps cimport map_clear
 
 from .attrs cimport LANG, ORTH
 from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
@@ -104,7 +105,7 @@ cdef class Vocab:
     def vectors(self, vectors):
         if hasattr(vectors, "strings"):
             for s in vectors.strings:
-                self.strings.add(s)
+                self.strings.add(s, allow_transient=False)
         self._vectors = vectors
         self._vectors.strings = self.strings

@@ -115,6 +116,10 @@
         langfunc = self.lex_attr_getters.get(LANG, None)
         return langfunc("_") if langfunc else ""
 
+    @property
+    def in_memory_zone(self) -> bool:
+        return self.mem is not self._non_temp_mem
+
     def __len__(self):
         """The current number of lexemes stored.
@@ -218,7 +223,7 @@
         # this size heuristic.
         mem = self.mem
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
-        lex.orth = self.strings.add(string)
+        lex.orth = self.strings.add(string, allow_transient=True)
         lex.length = len(string)
         if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
@@ -239,13 +244,13 @@
     cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
         self._by_orth.set(lex.orth, <void*>lex)
         self.length += 1
-        if is_transient:
+        if is_transient and self.in_memory_zone:
             self._transient_orths.push_back(lex.orth)
 
     def _clear_transient_orths(self):
         """Remove transient lexemes from the index (generally at the end of the memory zone)"""
         for orth in self._transient_orths:
-            self._by_orth.pop(orth)
+            map_clear(self._by_orth.c_map, orth)
         self._transient_orths.clear()
 
     def __contains__(self, key):
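And the Vocab side, sketched under the same assumptions: lexemes created inside a zone are tracked in _transient_orths and dropped from the orth index on exit, so membership checks flip back to False once the zone closes.

import spacy

nlp = spacy.blank("en")
with nlp.memory_zone():
    nlp("flurble")  # tokenizing creates a transient lexeme for "flurble"
    assert "flurble" in nlp.vocab  # indexed while the zone is active
assert "flurble" not in nlp.vocab  # cleared by _clear_transient_orths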
