Skip to content

Commit 9341e3b

Browse files
committed
Update torch.export test file (#1211)
1 parent ca11c34 commit 9341e3b

File tree

3 files changed

+203
-41
lines changed

3 files changed

+203
-41
lines changed

source/python.js

Lines changed: 145 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7595,16 +7595,26 @@ python.Execution = class {
75957595
});
75967596
this.registerType('sympy.core.relational.GreaterThan', class extends sympy.core.relational._Greater {
75977597
constructor(lhs, rhs) {
7598-
super(lhs, rhs, '>');
7598+
super(lhs, rhs, '>=');
75997599
}
76007600
});
76017601
this.registerType('sympy.core.relational._Less', class extends sympy.core.relational._Inequality {
76027602
});
76037603
this.registerType('sympy.core.relational.LessThan', class extends sympy.core.relational.Relational {
7604+
constructor(lhs, rhs) {
7605+
super(lhs, rhs, '<=');
7606+
}
7607+
});
7608+
this.registerType('sympy.core.relational.StrictLessThan', class extends sympy.core.relational.Relational {
76047609
constructor(lhs, rhs) {
76057610
super(lhs, rhs, '<');
76067611
}
76077612
});
7613+
this.registerType('sympy.core.relational.StrictGreaterThan', class extends sympy.core.relational.Relational {
7614+
constructor(lhs, rhs) {
7615+
super(lhs, rhs, '>');
7616+
}
7617+
});
76087618
this.registerType('sympy.core.relational.Equality', class extends sympy.core.relational.Relational {
76097619
constructor(lhs, rhs) {
76107620
super(lhs, rhs, '==');
@@ -7632,7 +7642,9 @@ python.Execution = class {
76327642
case 'Max': return new sympy.functions.elementary.miscellaneous.Max(...node.args.map((arg) => sympify(arg)));
76337643
case 'Integer': return new sympy.core.numbers.Integer(node.args[0].value);
76347644
case 'GreaterThan': return new sympy.core.relational.GreaterThan(sympify(node.args[0]), sympify(node.args[1]));
7645+
case 'StrictGreaterThan': return new sympy.core.relational.StrictGreaterThan(sympify(node.args[0]), sympify(node.args[1]));
76357646
case 'LessThan': return new sympy.core.relational.LessThan(sympify(node.args[0]), sympify(node.args[1]));
7647+
case 'StrictLessThan': return new sympy.core.relational.StrictLessThan(sympify(node.args[0]), sympify(node.args[1]));
76367648
case 'Equality': return new sympy.core.relational.Equality(sympify(node.args[0]), sympify(node.args[1]));
76377649
default: throw new python.Error(`Unsupported SymPy function '${node.func.id}'.`);
76387650
}
@@ -7652,15 +7664,22 @@ python.Execution = class {
76527664
if (node.op instanceof ast.Pow) {
76537665
return new sympy.core.power.Pow(sympify(node.left), sympify(node.right));
76547666
}
7667+
throw new python.Error(`Unsupported SymPy BinOp op '${node.op.__class__.__name__}'.`);
76557668
}
76567669
if (node instanceof ast.Compare) {
76577670
const left = sympify(node.left);
76587671
const right = sympify(node.comparators[0]);
76597672
const [op] = node.ops;
76607673
if (op instanceof ast.Gt) {
7674+
return new sympy.core.relational.StrictGreaterThan(left, right);
7675+
}
7676+
if (op instanceof ast.GtE) {
76617677
return new sympy.core.relational.GreaterThan(left, right);
76627678
}
76637679
if (op instanceof ast.Lt) {
7680+
return new sympy.core.relational.StrictLessThan(left, right);
7681+
}
7682+
if (op instanceof ast.LtE) {
76647683
return new sympy.core.relational.LessThan(left, right);
76657684
}
76667685
if (op instanceof ast.Eq) {
@@ -18575,7 +18594,12 @@ python.Execution = class {
1857518594
COMPLEXFLOAT: 10,
1857618595
COMPLEXDOUBLE: 11,
1857718596
BOOL: 12,
18578-
BFLOAT16: 13
18597+
BFLOAT16: 13,
18598+
UINT16: 28,
18599+
FLOAT8E4M3FN: 29,
18600+
FLOAT8E5M2: 30,
18601+
FLOAT8E4M3FNUZ: 31,
18602+
FLOAT8E5M2FNUZ: 32,
1857918603
};
1858018604
torch._export.serde.schema.Layout = {
1858118605
Unknown: 0,
@@ -18875,16 +18899,121 @@ python.Execution = class {
1887518899
}
1887618900
}
1887718901
});
18902+
this.registerFunction('torch.export.pt2_archive._package._load_state_dict', (f, model_name) => {
18903+
const legacy_file = `data/weights/${model_name}.pt`;
18904+
if (f.has(legacy_file)) {
18905+
return f.get(legacy_file);
18906+
}
18907+
const weights_config_file = `data/weights/${model_name}_weights_config.json`;
18908+
if (!f.has(weights_config_file)) {
18909+
return null;
18910+
}
18911+
const weights_config = f.get(weights_config_file);
18912+
const state_dict_file_map = torch.export.pt2_archive._package._build_file_map(f, weights_config, 'data/weights/');
18913+
const state_dict = new builtins.dict();
18914+
for (const [weight_fqn, payload_meta] of Object.entries(weights_config.config)) {
18915+
if (payload_meta.use_pickle) {
18916+
const weight_bytes = f.get(`data/weights/${payload_meta.path_name}`);
18917+
const weight_tensor = torch.load(weight_bytes);
18918+
state_dict.set(weight_fqn, weight_tensor);
18919+
} else {
18920+
const tensor_meta = payload_meta.tensor_meta;
18921+
const tensor = state_dict_file_map.get(payload_meta.path_name);
18922+
const sizes = tensor_meta.sizes.map((s) => s.as_int);
18923+
const strides = tensor_meta.strides.map((s) => s.as_int);
18924+
const storage_offset = tensor_meta.storage_offset.as_int;
18925+
const weight_tensor = new torch.Tensor();
18926+
weight_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
18927+
weight_tensor.requires_grad = tensor_meta.requires_grad || false;
18928+
if (payload_meta.is_param) {
18929+
state_dict.set(weight_fqn, new torch.nn.parameter.Parameter(weight_tensor, tensor_meta.requires_grad));
18930+
} else {
18931+
state_dict.set(weight_fqn, weight_tensor);
18932+
}
18933+
}
18934+
}
18935+
return state_dict;
18936+
});
18937+
this.registerFunction('torch.export.pt2_archive._package._load_constants', (f, model_name) => {
18938+
const legacy_file = `data/constants/${model_name}.pt`;
18939+
if (f.has(legacy_file)) {
18940+
const entries = f.get(legacy_file);
18941+
return new builtins.dict(entries);
18942+
}
18943+
const constants_config_file = `data/constants/${model_name}_constants_config.json`;
18944+
if (!f.has(constants_config_file)) {
18945+
return null;
18946+
}
18947+
const constants_config = f.get(constants_config_file);
18948+
const constant_file_map = torch.export.pt2_archive._package._build_file_map(f, constants_config, 'data/constants/');
18949+
const constants = new builtins.dict();
18950+
for (const [constant_fqn, payload_meta] of Object.entries(constants_config.config)) {
18951+
const path_name = payload_meta.path_name;
18952+
if (path_name.startsWith('tensor_')) {
18953+
if (payload_meta.use_pickle) {
18954+
const constant_bytes = f.get(`data/constants/${payload_meta.path_name}`);
18955+
const constant_tensor = torch.load(constant_bytes);
18956+
constants.set(constant_fqn, constant_tensor);
18957+
} else {
18958+
const tensor_meta = payload_meta.tensor_meta;
18959+
const tensor = constant_file_map.get(payload_meta.path_name);
18960+
const sizes = tensor_meta.sizes.map((s) => s.as_int);
18961+
const strides = tensor_meta.strides.map((s) => s.as_int);
18962+
const storage_offset = tensor_meta.storage_offset.as_int;
18963+
const constant_tensor = new torch.Tensor();
18964+
constant_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
18965+
constants.set(constant_fqn, constant_tensor);
18966+
}
18967+
} else if (payload_meta.path_name.startsWith('custom_obj_')) {
18968+
const custom_obj_bytes = f.get(`data/constants/${payload_meta.path_name}`);
18969+
const custom_obj = torch._C._pickle_load_obj(custom_obj_bytes);
18970+
constants.set(constant_fqn, custom_obj);
18971+
}
18972+
}
18973+
return constants;
18974+
});
18975+
this.registerFunction('torch._export.serde.serialize.deserialize_scalar_type', (st) => {
18976+
if (!torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.has(st)) {
18977+
throw new python.Error(`Unsupported scalar type '${st}'.`);
18978+
}
18979+
return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.get(st);
18980+
});
18981+
this.registerFunction('torch.export.pt2_archive._package._build_file_map', (archive_reader, config, base_dir) => {
18982+
const file_map = new builtins.dict();
18983+
for (const payload_meta of Object.values(config.config)) {
18984+
if (payload_meta.use_pickle) {
18985+
continue;
18986+
}
18987+
if (file_map.has(payload_meta.path_name)) {
18988+
continue;
18989+
}
18990+
const tensor_bytes = archive_reader.get(`${base_dir}${payload_meta.path_name}`);
18991+
const tensor = torch.export.pt2_archive._package._create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta);
18992+
file_map.set(payload_meta.path_name, tensor);
18993+
}
18994+
return file_map;
18995+
});
18996+
this.registerFunction('torch.export.pt2_archive._package._create_flat_tensor_from_bytes', (tensor_bytes, tensor_meta) => {
18997+
const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
18998+
const itemsize = dtype.itemsize();
18999+
const num_elements = tensor_bytes.length / itemsize;
19000+
const storage = new torch.storage.TypedStorage(num_elements, dtype);
19001+
storage._set_cdata(tensor_bytes);
19002+
const tensor = new torch.Tensor();
19003+
tensor.__setstate__([storage, 0, [num_elements], [1]]);
19004+
tensor.requires_grad = tensor_meta.requires_grad || false;
19005+
return tensor;
19006+
});
1887819007
this.registerFunction('torch.export.pt2_archive._package.load_pt2', (f, expected_opset_version) => {
1887919008
const exported_programs = new Map();
1888019009
for (const name of f.keys()) {
1888119010
const match = name.match(/^models\/([^/]+)\.json$/);
1888219011
if (match) {
1888319012
const [, model_name] = match;
1888419013
const serialized_exported_program = f.get(`models/${model_name}.json`);
18885-
const serialized_state_dict = f.get(`data/weights/${model_name}.pt`);
18886-
const serialized_constants = f.get(`data/constants/${model_name}.pt`);
18887-
const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`);
19014+
const serialized_state_dict = torch.export.pt2_archive._package._load_state_dict(f, model_name);
19015+
const serialized_constants = torch.export.pt2_archive._package._load_constants(f, model_name);
19016+
const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`, 'zip');
1888819017
const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs);
1888919018
const exported_program = torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
1889019019
exported_programs.set(model_name, exported_program);
@@ -18942,7 +19071,10 @@ python.Execution = class {
1894219071
}
1894319072
});
1894419073
this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => {
18945-
if (!serialized) {
19074+
if (serialized instanceof builtins.dict || serialized instanceof builtins.tuple) {
19075+
return serialized;
19076+
}
19077+
if (serialized === null || serialized.length === 0) {
1894619078
return new builtins.dict();
1894719079
}
1894819080
const artifact = torch.load(serialized);
@@ -19217,8 +19349,8 @@ python.Execution = class {
1921719349
this.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
1921819350
*/
1921919351
this.example_inputs = null;
19220-
if (example_inputs && example_inputs.length > 0) {
19221-
torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
19352+
if (example_inputs) {
19353+
this.example_inputs = torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
1922219354
}
1922319355
this.deserialize_graph(serialized_graph_module.graph);
1922419356
const module_call_graph = null; // this.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
@@ -19265,7 +19397,7 @@ python.Execution = class {
1926519397
} else if (typ_ === 'as_tensor') {
1926619398
return this.serialized_name_to_node.get(inp.as_tensor.name);
1926719399
} else if (typ_ === 'as_scalar_type') {
19268-
return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type];
19400+
return torch._export.serde.serialize.deserialize_scalar_type(inp.as_scalar_type);
1926919401
} else if (typ_ === 'as_memory_format') {
1927019402
return torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format];
1927119403
} else if (typ_ === 'as_layout') {
@@ -19499,7 +19631,7 @@ python.Execution = class {
1949919631
const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));
1950019632
const strides = tensor_meta.strides.map((val) => this.deserialize_sym_int(val));
1950119633
const device = this.deserialize_device(tensor_meta.device);
19502-
const dtype = torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype];
19634+
const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
1950319635
return torch.empty_strided(sizes, strides, dtype, null, device);
1950419636
} finally {
1950519637
this.fake_tensor_mode.__exit__(null, null, null);
@@ -20350,13 +20482,13 @@ python.Execution = class {
2035020482
torch.uint16 = new torch.dtype(27, 'uint16', 2);
2035120483
torch.uint32 = new torch.dtype(28, 'uint32', 4);
2035220484
torch.uint64 = new torch.dtype(29, 'uint64', 8);
20353-
torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = Object.fromEntries([
20485+
torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = new Map([
2035420486
['uint8', 'BYTE'],
2035520487
['int8', 'CHAR'], ['int16', 'SHORT'], ['int32', 'INT'], ['int64', 'LONG'],
2035620488
['float16', 'HALF'], ['float32', 'FLOAT'], ['float64', 'DOUBLE'],
2035720489
['complex32', 'COMPLEXHALF'], ['complex64', 'COMPLEXFLOAT'], ['complex128', 'COMPLEXDOUBLE'],
20358-
['bool', 'BOOL'],
20359-
['bfloat16', 'BFLOAT16']
20490+
['bool', 'BOOL'], ['bfloat16', 'BFLOAT16'], ['uint16', 'UINT16'],
20491+
['float8_e4m3fn','FLOAT8E4M3FN'], ['float8_e5m2','FLOAT8E5M2'], ['float8_e4m3fnuz','FLOAT8E4M3FNUZ'], ['float8_e5m2fnuz','FLOAT8E5M2FNUZ']
2036020492
].map(([key, value]) => [torch._export.serde.schema.ScalarType[value], torch[key]]));
2036120493
torch.contiguous_format = new torch.memory_format('contiguous_format');
2036220494
torch.channels_last = new torch.memory_format('channels_last');

0 commit comments

Comments
 (0)