core

Core functionality for aligning the tokens produced by a transformers tokenizer with the AST nodes produced by a tree-sitter parser.

source

ASTNode

 ASTNode (node, is_internal, is_builtin, node_types)

Wrap a tree-sitter node, recording whether it refers to an internal method (is_internal) and whether it refers to a builtin (is_builtin), together with the grammar's node types.

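As a rough illustration of what gets wrapped, the snippet below parses a small Python source with tree-sitter and inspects the resulting node. The tree_sitter_languages helper package and the argument values sketched for ASTNode are assumptions for illustration, not the library's required setup.

# a minimal sketch, assuming the tree_sitter_languages package for a prebuilt grammar
from tree_sitter_languages import get_parser

parser = get_parser("python")
tree = parser.parse(bytes("def foo():\n    print('hi')\n", "utf8"))
root = tree.root_node

print(root.type)                                # 'module'
print([child.type for child in root.children])  # ['function_definition']

# an ASTNode would wrap such a node (flag and node_types values illustrative only):
# ASTNode(root, is_internal=False, is_builtin=False, node_types=node_types)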

source

traverse

 traverse (node, results)

Recursively traverse a tree-sitter node and append the results to a list.

Parameters:
  node: tree-sitter node
  results: list to append results to
Returns: None
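
The sketch below shows the recursive pattern such a traversal follows. It is a simplified stand-in, not the library's implementation; FakeNode is a hypothetical minimal node exposing only the attributes the walk touches.

# a minimal sketch of a pre-order walk over tree-sitter-style nodes
class FakeNode:
    def __init__(self, type, children=()):
        self.type = type
        self.children = list(children)

def traverse_sketch(node, results):
    results.append(node.type)         # record the current node...
    for child in node.children:       # ...then recurse into each child
        traverse_sketch(child, results)

tree = FakeNode("module", [FakeNode("function_definition", [FakeNode("identifier")])])
results = []
traverse_sketch(tree, results)
assert results == ["module", "function_definition", "identifier"]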

source

get_token_type

 get_token_type (tok_span:tuple, nodes:list, lines:list,
                 internal_methods:list, acceptable_ast_types:list,
                 node_types:list)

Get the parent AST type and token AST type of a token.

Parameters:
  tok_span (tuple): (start, end) position of a token
  nodes (list): list of tree-sitter nodes
  lines (list): list of lines in the code
  internal_methods (list): list of internal methods
  acceptable_ast_types (list): list of AST types to accept for internal methods
  node_types (list): list of node types
Returns (tuple): (parent_type, token_type) of the token
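
For orientation, the (start, end) spans come from a fast tokenizer's offset_mapping. The sketch below derives such spans for a small snippet; it stops short of calling get_token_type, which additionally needs the traversed nodes and the node-type lists described above.

# a minimal sketch: deriving tok_span values from a fast tokenizer
from transformers import AutoTokenizer

code = "def foo():\n    print('hello world')"
lines = code.split("\n")  # the `lines` argument expected by get_token_type

tok = AutoTokenizer.from_pretrained("gpt2")
enc = tok(code, return_offsets_mapping=True)

for tok_span in enc["offset_mapping"]:
    start, end = tok_span
    print(tok_span, repr(code[start:end]))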

source

CodeTokenizer

 CodeTokenizer (tokenizer, parser, node_types, name_or_path, program_lang,
                padding_token)

A code tokenizer that aligns the tokens with their AST nodes.

Parameters:
  tokenizer: transformers tokenizer
  parser: tree-sitter parser
  node_types: list of node types
  name_or_path: name or path of the tokenizer
  program_lang: programming language of the tokenizer
  padding_token: whether to add a padding token
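
The cells below exercise the tokenizer. from_pretrained builds a CodeTokenizer from a Hugging Face tokenizer name and a tree-sitter language, and the resulting encoding carries the AST alignment fields (ast_ids, parent_ast_ids, merged_ast, is_internal_methods, is_builtins) alongside the usual input_ids, with every field aligned one-to-one with the tokens.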
# test the tokenizer
py_tokenizer = CodeTokenizer.from_pretrained("gpt2", "python")
code = "def foo():\n    print('hello world')"

encoding = py_tokenizer(code)

assert "ast_ids" in encoding
assert "parent_ast_ids" in encoding
assert "merged_ast" in encoding
assert len(encoding["ast_ids"]) == len(encoding["input_ids"])
assert len(encoding["parent_ast_ids"]) == len(encoding["input_ids"])
assert len(encoding["merged_ast"]) == len(encoding["input_ids"])
assert len(encoding["is_internal_methods"]) == len(encoding["input_ids"])
assert len(encoding["is_builtins"]) == len(encoding["input_ids"])
# test with list of code
code = ["def foo():\n    print('hello world')", "def bar():\n    print('hello world')"]
encoding = py_tokenizer(code)

assert "ast_ids" in encoding
assert "parent_ast_ids" in encoding
assert "merged_ast" in encoding
assert len(encoding["ast_ids"]) == len(encoding["input_ids"])
assert len(encoding["parent_ast_ids"]) == len(encoding["input_ids"])
assert len(encoding["merged_ast"]) == len(encoding["input_ids"])
assert len(encoding["is_internal_methods"]) == len(encoding["input_ids"])
assert len(encoding["is_builtins"]) == len(encoding["input_ids"])
assert len(encoding["ast_ids"][0]) == len(encoding["input_ids"][0])
assert len(encoding["parent_ast_ids"][0]) == len(encoding["input_ids"][0])
assert len(encoding["merged_ast"][0]) == len(encoding["input_ids"][0])
assert len(encoding["is_internal_methods"][0]) == len(encoding["input_ids"][0])
assert len(encoding["is_builtins"][0]) == len(encoding["input_ids"][0])
# test with internal methods
code = "def print():\n    print('print') #print\n    print = 1"
encoding = py_tokenizer(code, internal_methods=["print"])

for i in range(len(encoding["input_ids"])):
    if (
        "call" in encoding["merged_ast"][i]
        or "argument_list" in encoding["merged_ast"][i]
    ):
        assert encoding["is_internal_methods"][i] == True, encoding["merged_ast"][i]
    else:
        assert encoding["is_internal_methods"][i] == False, encoding["merged_ast"][i]
# test with internal methods and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2, internal_methods=[["print"], ["print"]])

for i in range(len(encoding["input_ids"])):
    for j in range(len(encoding["input_ids"][i])):
        if (
            "call" in encoding["merged_ast"][i][j]
            or "argument_list" in encoding["merged_ast"][i][j]
        ):
            assert encoding["is_internal_methods"][i][j] == True, encoding[
                "merged_ast"
            ][i][j]
        else:
            assert encoding["is_internal_methods"][i][j] == False, encoding[
                "merged_ast"
            ][i][j]
# test without internal methods
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer(code)

for i in range(len(encoding["input_ids"])):
    assert encoding["is_internal_methods"][i] == False
# test without internal methods and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2)

for i in range(len(encoding["input_ids"])):
    for j in range(len(encoding["input_ids"][i])):
        assert encoding["is_internal_methods"][i][j] == False
# test with builtins
code = "def foo():\n    print('print') #print\n    print = 1"
encoding = py_tokenizer(code)

for i in range(len(encoding["input_ids"])):
    if "call" in encoding["merged_ast"][i]:
        assert encoding["is_builtins"][i] == True, encoding["merged_ast"][i]
    else:
        assert encoding["is_builtins"][i] == False, encoding["merged_ast"][i]
# test with builtins and batched
code = "def foo():\n    print('print') #print"
encoding = py_tokenizer([code] * 2)

for i in range(len(encoding["input_ids"])):
    for j in range(len(encoding["input_ids"][i])):
        if "call" in encoding["merged_ast"][i][j]:
            assert encoding["is_builtins"][i][j] == True, encoding["merged_ast"][i][j]
        else:
            assert encoding["is_builtins"][i][j] == False, encoding["merged_ast"][i][j]
# test the pickleability of the tokenizer (datasets.map with num_proc > 1 pickles it for worker processes)
import pickle

assert py_tokenizer == pickle.loads(pickle.dumps(py_tokenizer))
# test that multi-process tokenization is faster than single-process tokenization
import time
from datasets import load_dataset

ds = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))

start = time.time()
single_proc_ds = ds.map(
    lambda x: py_tokenizer(x["content"]),
    batched=False,
    batch_size=1,
    num_proc=1,
    load_from_cache_file=False,
)
total_single_proc = time.time() - start

start = time.time()
multi_proc_ds = ds.map(
    lambda x: py_tokenizer(x["content"]),
    batched=False,
    batch_size=1,
    num_proc=4,
    load_from_cache_file=False,
)
total_multi_proc = time.time() - start

assert total_multi_proc < total_single_proc
# test that the two datasets tokenized with single and multi processing are identical

for i in range(len(ds)):
    assert single_proc_ds[i]["input_ids"] == multi_proc_ds[i]["input_ids"]
    assert single_proc_ds[i]["attention_mask"] == multi_proc_ds[i]["attention_mask"]
    assert single_proc_ds[i]["offset_mapping"] == multi_proc_ds[i]["offset_mapping"]
    assert single_proc_ds[i]["ast_ids"] == multi_proc_ds[i]["ast_ids"]
    assert single_proc_ds[i]["parent_ast_ids"] == multi_proc_ds[i]["parent_ast_ids"]
    assert single_proc_ds[i]["merged_ast"] == multi_proc_ds[i]["merged_ast"]
    assert (
        single_proc_ds[i]["is_internal_methods"]
        == multi_proc_ds[i]["is_internal_methods"]
    )
    assert single_proc_ds[i]["is_builtins"] == multi_proc_ds[i]["is_builtins"]