Skip to content

Commit 44aa3be

Browse files
committed
Add common grammars and json-schema-to-grammar utility function from llama.cpp
1 parent 305482b commit 44aa3be

File tree

1 file changed

+309
-6
lines changed

1 file changed

+309
-6
lines changed

llama_cpp/llama_grammar.py

Lines changed: 309 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""C++ implementation of the llama grammar parser."""
1+
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
2+
23
# flake8: noqa
34
from pathlib import Path
45
import sys
@@ -1056,8 +1057,7 @@ def print_rule(
10561057
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
10571058
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
10581059
raise RuntimeError(
1059-
"malformed rule, does not end with LLAMA_GRETYPE_END: "
1060-
+ str(rule_id)
1060+
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
10611061
)
10621062
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
10631063
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@@ -1102,9 +1102,7 @@ def print_rule(
11021102
for i, elem in enumerate(rule[:-1]):
11031103
case = elem.type # type: llama_gretype
11041104
if case is llama_gretype.LLAMA_GRETYPE_END:
1105-
raise RuntimeError(
1106-
"unexpected end of rule: " + str(rule_id) + "," + str(i)
1107-
)
1105+
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
11081106
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
11091107
print("| ", file=file, end="")
11101108
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
@@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
11861184
f"{print_grammar.__name__}: error printing grammar: {err}",
11871185
file=sys.stderr,
11881186
)
1187+
1188+
1189+
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
1190+
1191+
# GBNF grammar for lines of the form `<expr> = <term>`.
# Raw string: backslash escapes such as \n and \t are passed through verbatim
# so the grammar text stays one rule per line and matches the other grammars.
ARITHMETIC_GBNF = r"""root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""
1199+
1200+
C_GBNF = """\
1201+
root ::= (declaration)*
1202+
1203+
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
1204+
1205+
dataType ::= "int" ws | "float" ws | "char" ws
1206+
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
1207+
1208+
parameter ::= dataType identifier
1209+
1210+
statement ::=
1211+
( dataType identifier ws "=" ws expression ";" ) |
1212+
( identifier ws "=" ws expression ";" ) |
1213+
( identifier ws "(" argList? ")" ";" ) |
1214+
( "return" ws expression ";" ) |
1215+
( "while" "(" condition ")" "{" statement* "}" ) |
1216+
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
1217+
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
1218+
( singleLineComment ) |
1219+
( multiLineComment )
1220+
1221+
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
1222+
forUpdate ::= identifier ws "=" ws expression
1223+
1224+
condition ::= expression relationOperator expression
1225+
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
1226+
1227+
expression ::= term (("+" | "-") term)*
1228+
term ::= factor(("*" | "/") factor)*
1229+
1230+
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
1231+
unaryTerm ::= "-" factor
1232+
funcCall ::= identifier "(" argList? ")"
1233+
parenExpression ::= "(" ws expression ws ")"
1234+
1235+
argList ::= expression ("," ws expression)*
1236+
1237+
number ::= [0-9]+
1238+
1239+
singleLineComment ::= "//" [^\n]* "\n"
1240+
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
1241+
1242+
ws ::= ([ \t\n]+)
1243+
"""
1244+
1245+
# NOTE(review): despite its name, this text is rule-for-rule the JSON object
# grammar (same rules as JSON_GBNF), not a chess-move grammar. Presumably a
# copy-paste slip; confirm against vendor/llama.cpp/grammars/chess.gbnf.
#
# Raw string: in a regular triple-quoted string Python would turn "\"" into
# three bare quotes and "\\" into "\", corrupting the grammar's own escapes.
CHESS_GBNF = r"""root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
1272+
1273+
# NOTE(review): despite its name, this text is rule-for-rule the JSON object
# grammar (same rules as JSON_GBNF), not a Japanese-text grammar. Presumably a
# copy-paste slip; confirm against vendor/llama.cpp/grammars/japanese.gbnf.
#
# Raw string: in a regular triple-quoted string Python would turn "\"" into
# three bare quotes and "\\" into "\", corrupting the grammar's own escapes.
JAPANESE_GBNF = r"""root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
1300+
1301+
# GBNF grammar whose root is a JSON array with one element per line.
# Raw string: in a regular triple-quoted string Python would turn "\"" into
# three bare quotes and "\\" into "\", corrupting the grammar's own escapes.
JSON_ARR_GBNF = r"""# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws

arr ::=
  "[\n" ws (
    value
    (",\n" ws value)*
  )? "]"

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
1337+
1338+
1339+
# GBNF grammar whose root is a JSON object.
# Raw string: in a regular triple-quoted string Python would turn "\"" into
# three bare quotes and "\\" into "\", corrupting the grammar's own escapes.
JSON_GBNF = r"""root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?"""
1365+
1366+
# GBNF grammar for a bulleted list: one or more "- <text>" lines.
# Raw string, so the \r, \n, \xNN and \uNNNN escapes are left for the grammar
# parser to decode instead of being expanded by Python.
# NOTE(review): assumes the parser's character handling decodes \x and \u
# escapes -- confirm against the parser in this module.
LIST_GBNF = r"""root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""
1372+
1373+
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
1374+
import json
1375+
import re
1376+
from typing import List, Optional
1377+
1378+
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'

# GBNF right-hand sides for the JSON-schema primitive types, keyed by the
# schema's "type" value. Every rule ends with the shared `space` rule.
PRIMITIVE_RULES = {
    "boolean": '("true" | "false") space',
    "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
    "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
    "string": r""" "\"" (
        [^"\\] |
        "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
      )* "\"" space """,
    "null": '"null" space',
}

# Runs of characters that are replaced by "-" when a rule name is sanitized.
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
# Characters that must be escaped inside a GBNF string literal, and the
# replacement text for each of them.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
1396+
1397+
1398+
class SchemaConverter:
    """Translate a JSON-schema dict into a set of GBNF grammar rules.

    Rules accumulate in ``self._rules`` (rule name -> right-hand side) while
    :meth:`visit` walks the schema; :meth:`format_grammar` renders them as text.
    """

    def __init__(self, prop_order):
        """Create a converter.

        Args:
            prop_order: Mapping of property name -> sort position, used to
                order object properties; unlisted names sort after listed ones.
        """
        self._prop_order = prop_order
        # Every generated rule refers to `space`, so it is always present.
        self._rules = {"space": SPACE_RULE}

    @staticmethod
    def _child_name(name, suffix):
        """Join a parent rule name and a suffix with "-" (no dash if the parent is empty)."""
        return f'{name}{"-" if name else ""}{suffix}'

    def _format_literal(self, literal):
        """Return a GBNF string literal matching the JSON encoding of *literal*."""
        # json.dumps yields the JSON text (including its own quotes for strings);
        # CR, LF and double quotes are then escaped for the grammar syntax.
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
        )
        return f'"{escaped}"'

    def _add_rule(self, name, rule):
        """Register *rule* under a sanitized form of *name* and return the key used.

        If the sanitized name is already bound to a different rule, the first
        unused ``<name><i>`` (i = 0, 1, ...) is used instead.
        """
        esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f"{esc_name}{i}" in self._rules:
                i += 1
            key = f"{esc_name}{i}"
        self._rules[key] = rule
        return key

    def visit(self, schema, name):
        """Add the rule(s) for *schema* and return the name of its top rule.

        Args:
            schema: JSON-schema node (a dict).
            name: Rule-name prefix for this node; an empty value means "root".

        Raises:
            AssertionError: If the node matches none of the handled keywords
                and its "type" is not a key of ``PRIMITIVE_RULES``.
        """
        schema_type = schema.get("type")
        rule_name = name or "root"

        if "oneOf" in schema or "anyOf" in schema:
            alternatives = schema.get("oneOf") or schema["anyOf"]
            rule = " | ".join(
                self.visit(alt_schema, self._child_name(name, i))
                for i, alt_schema in enumerate(alternatives)
            )
            return self._add_rule(rule_name, rule)

        elif "const" in schema:
            return self._add_rule(rule_name, self._format_literal(schema["const"]))

        elif "enum" in schema:
            rule = " | ".join(self._format_literal(v) for v in schema["enum"])
            return self._add_rule(rule_name, rule)

        elif schema_type == "object" and "properties" in schema:
            # TODO: `required` keyword
            prop_order = self._prop_order
            prop_pairs = sorted(
                schema["properties"].items(),
                # sort by position in prop_order (if specified) then by key
                key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
            )

            rule = '"{" space'
            for i, (prop_name, prop_schema) in enumerate(prop_pairs):
                prop_rule_name = self.visit(
                    prop_schema, self._child_name(name, prop_name)
                )
                if i > 0:
                    rule += ' "," space'
                rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
            rule += ' "}" space'

            return self._add_rule(rule_name, rule)

        elif schema_type == "array" and "items" in schema:
            # TODO `prefixItems` keyword
            item_rule_name = self.visit(
                schema["items"], self._child_name(name, "item")
            )
            rule = (
                f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
            )
            return self._add_rule(rule_name, rule)

        else:
            # Raised explicitly rather than via `assert` so the check still runs
            # under `python -O`; the exception type is kept for compatibility.
            if schema_type not in PRIMITIVE_RULES:
                raise AssertionError(f"Unrecognized schema: {schema}")
            # Primitive rules are shared by type name unless this is the root.
            return self._add_rule(
                "root" if rule_name == "root" else schema_type,
                PRIMITIVE_RULES[schema_type],
            )

    def format_grammar(self):
        """Return all collected rules as ``name ::= rule`` lines joined by newlines."""
        return "\n".join(f"{name} ::= {rule}" for name, rule in self._rules.items())
1483+
1484+
1485+
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema into GBNF grammar text.

    Args:
        schema: The JSON schema as JSON text. An object with a ``read``
            method (an open file) is also accepted and parsed from its content.
        prop_order: Optional list of property names; object properties listed
            here are emitted in this order, ahead of unlisted ones.

    Returns:
        The grammar produced by ``SchemaConverter.format_grammar()``.
    """
    # `schema` is documented as a string, so it must be parsed with json.loads;
    # json.load only works on file-like objects, which are still supported.
    if hasattr(schema, "read"):
        parsed_schema = json.load(schema)
    else:
        parsed_schema = json.loads(schema)

    # Map each property name to its position for use as a sort key.
    order_index = {name: idx for idx, name in enumerate(prop_order or [])}

    converter = SchemaConverter(order_index)
    converter.visit(parsed_schema, "")
    return converter.format_grammar()

0 commit comments

Comments
 (0)