@@ -7595,16 +7595,26 @@ python.Execution = class {
75957595 });
75967596 this.registerType('sympy.core.relational.GreaterThan', class extends sympy.core.relational._Greater {
75977597 constructor(lhs, rhs) {
7598- super(lhs, rhs, '>');
7598+ super(lhs, rhs, '>=');
75997599 }
76007600 });
76017601 this.registerType('sympy.core.relational._Less', class extends sympy.core.relational._Inequality {
76027602 });
76037603 this.registerType('sympy.core.relational.LessThan', class extends sympy.core.relational.Relational {
7604+ constructor(lhs, rhs) {
7605+ super(lhs, rhs, '<=');
7606+ }
7607+ });
7608+ this.registerType('sympy.core.relational.StrictLessThan', class extends sympy.core.relational.Relational {
76047609 constructor(lhs, rhs) {
76057610 super(lhs, rhs, '<');
76067611 }
76077612 });
7613+ this.registerType('sympy.core.relational.StrictGreaterThan', class extends sympy.core.relational.Relational {
7614+ constructor(lhs, rhs) {
7615+ super(lhs, rhs, '>');
7616+ }
7617+ });
76087618 this.registerType('sympy.core.relational.Equality', class extends sympy.core.relational.Relational {
76097619 constructor(lhs, rhs) {
76107620 super(lhs, rhs, '==');
@@ -7632,7 +7642,9 @@ python.Execution = class {
76327642 case 'Max': return new sympy.functions.elementary.miscellaneous.Max(...node.args.map((arg) => sympify(arg)));
76337643 case 'Integer': return new sympy.core.numbers.Integer(node.args[0].value);
76347644 case 'GreaterThan': return new sympy.core.relational.GreaterThan(sympify(node.args[0]), sympify(node.args[1]));
7645+ case 'StrictGreaterThan': return new sympy.core.relational.StrictGreaterThan(sympify(node.args[0]), sympify(node.args[1]));
76357646 case 'LessThan': return new sympy.core.relational.LessThan(sympify(node.args[0]), sympify(node.args[1]));
7647+ case 'StrictLessThan': return new sympy.core.relational.StrictLessThan(sympify(node.args[0]), sympify(node.args[1]));
76367648 case 'Equality': return new sympy.core.relational.Equality(sympify(node.args[0]), sympify(node.args[1]));
76377649 default: throw new python.Error(`Unsupported SymPy function '${node.func.id}'.`);
76387650 }
@@ -7652,15 +7664,22 @@ python.Execution = class {
76527664 if (node.op instanceof ast.Pow) {
76537665 return new sympy.core.power.Pow(sympify(node.left), sympify(node.right));
76547666 }
7667+ throw new python.Error(`Unsupported SymPy BinOp op '${node.op.__class__.__name__}'.`);
76557668 }
76567669 if (node instanceof ast.Compare) {
76577670 const left = sympify(node.left);
76587671 const right = sympify(node.comparators[0]);
76597672 const [op] = node.ops;
76607673 if (op instanceof ast.Gt) {
7674+ return new sympy.core.relational.StrictGreaterThan(left, right);
7675+ }
7676+ if (op instanceof ast.GtE) {
76617677 return new sympy.core.relational.GreaterThan(left, right);
76627678 }
76637679 if (op instanceof ast.Lt) {
7680+ return new sympy.core.relational.StrictLessThan(left, right);
7681+ }
7682+ if (op instanceof ast.LtE) {
76647683 return new sympy.core.relational.LessThan(left, right);
76657684 }
76667685 if (op instanceof ast.Eq) {
@@ -18575,7 +18594,12 @@ python.Execution = class {
1857518594 COMPLEXFLOAT: 10,
1857618595 COMPLEXDOUBLE: 11,
1857718596 BOOL: 12,
18578- BFLOAT16: 13
18597+ BFLOAT16: 13,
18598+ UINT16: 28,
18599+ FLOAT8E4M3FN: 29,
18600+ FLOAT8E5M2: 30,
18601+ FLOAT8E4M3FNUZ: 31,
18602+ FLOAT8E5M2FNUZ: 32,
1857918603 };
1858018604 torch._export.serde.schema.Layout = {
1858118605 Unknown: 0,
@@ -18875,16 +18899,121 @@ python.Execution = class {
1887518899 }
1887618900 }
1887718901 });
18902+ this.registerFunction('torch.export.pt2_archive._package._load_state_dict', (f, model_name) => {
18903+ const legacy_file = `data/weights/${model_name}.pt`;
18904+ if (f.has(legacy_file)) {
18905+ return f.get(legacy_file);
18906+ }
18907+ const weights_config_file = `data/weights/${model_name}_weights_config.json`;
18908+ if (!f.has(weights_config_file)) {
18909+ return null;
18910+ }
18911+ const weights_config = f.get(weights_config_file);
18912+ const state_dict_file_map = torch.export.pt2_archive._package._build_file_map(f, weights_config, 'data/weights/');
18913+ const state_dict = new builtins.dict();
18914+ for (const [weight_fqn, payload_meta] of Object.entries(weights_config.config)) {
18915+ if (payload_meta.use_pickle) {
18916+ const weight_bytes = f.get(`data/weights/${payload_meta.path_name}`);
18917+ const weight_tensor = torch.load(weight_bytes);
18918+ state_dict.set(weight_fqn, weight_tensor);
18919+ } else {
18920+ const tensor_meta = payload_meta.tensor_meta;
18921+ const tensor = state_dict_file_map.get(payload_meta.path_name);
18922+ const sizes = tensor_meta.sizes.map((s) => s.as_int);
18923+ const strides = tensor_meta.strides.map((s) => s.as_int);
18924+ const storage_offset = tensor_meta.storage_offset.as_int;
18925+ const weight_tensor = new torch.Tensor();
18926+ weight_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
18927+ weight_tensor.requires_grad = tensor_meta.requires_grad || false;
18928+ if (payload_meta.is_param) {
18929+ state_dict.set(weight_fqn, new torch.nn.parameter.Parameter(weight_tensor, tensor_meta.requires_grad));
18930+ } else {
18931+ state_dict.set(weight_fqn, weight_tensor);
18932+ }
18933+ }
18934+ }
18935+ return state_dict;
18936+ });
18937+ this.registerFunction('torch.export.pt2_archive._package._load_constants', (f, model_name) => {
18938+ const legacy_file = `data/constants/${model_name}.pt`;
18939+ if (f.has(legacy_file)) {
18940+ const entries = f.get(legacy_file);
18941+ return new builtins.dict(entries);
18942+ }
18943+ const constants_config_file = `data/constants/${model_name}_constants_config.json`;
18944+ if (!f.has(constants_config_file)) {
18945+ return null;
18946+ }
18947+ const constants_config = f.get(constants_config_file);
18948+ const constant_file_map = torch.export.pt2_archive._package._build_file_map(f, constants_config, 'data/constants/');
18949+ const constants = new builtins.dict();
18950+ for (const [constant_fqn, payload_meta] of Object.entries(constants_config.config)) {
18951+ const path_name = payload_meta.path_name;
18952+ if (path_name.startsWith('tensor_')) {
18953+ if (payload_meta.use_pickle) {
18954+ const constant_bytes = f.get(`data/constants/${payload_meta.path_name}`);
18955+ const constant_tensor = torch.load(constant_bytes);
18956+ constants.set(constant_fqn, constant_tensor);
18957+ } else {
18958+ const tensor_meta = payload_meta.tensor_meta;
18959+ const tensor = constant_file_map.get(payload_meta.path_name);
18960+ const sizes = tensor_meta.sizes.map((s) => s.as_int);
18961+ const strides = tensor_meta.strides.map((s) => s.as_int);
18962+ const storage_offset = tensor_meta.storage_offset.as_int;
18963+ const constant_tensor = new torch.Tensor();
18964+ constant_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
18965+ constants.set(constant_fqn, constant_tensor);
18966+ }
18967+ } else if (payload_meta.path_name.startsWith('custom_obj_')) {
18968+ const custom_obj_bytes = f.get(`data/constants/${payload_meta.path_name}`);
18969+ const custom_obj = torch._C._pickle_load_obj(custom_obj_bytes);
18970+ constants.set(constant_fqn, custom_obj);
18971+ }
18972+ }
18973+ return constants;
18974+ });
18975+ this.registerFunction('torch._export.serde.serialize.deserialize_scalar_type', (st) => {
18976+ if (!torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.has(st)) {
18977+ throw new python.Error(`Unsupported scalar type '${st}'.`);
18978+ }
18979+ return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.get(st);
18980+ });
18981+ this.registerFunction('torch.export.pt2_archive._package._build_file_map', (archive_reader, config, base_dir) => {
18982+ const file_map = new builtins.dict();
18983+ for (const payload_meta of Object.values(config.config)) {
18984+ if (payload_meta.use_pickle) {
18985+ continue;
18986+ }
18987+ if (file_map.has(payload_meta.path_name)) {
18988+ continue;
18989+ }
18990+ const tensor_bytes = archive_reader.get(`${base_dir}${payload_meta.path_name}`);
18991+ const tensor = torch.export.pt2_archive._package._create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta);
18992+ file_map.set(payload_meta.path_name, tensor);
18993+ }
18994+ return file_map;
18995+ });
18996+ this.registerFunction('torch.export.pt2_archive._package._create_flat_tensor_from_bytes', (tensor_bytes, tensor_meta) => {
18997+ const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
18998+ const itemsize = dtype.itemsize();
18999+ const num_elements = tensor_bytes.length / itemsize;
19000+ const storage = new torch.storage.TypedStorage(num_elements, dtype);
19001+ storage._set_cdata(tensor_bytes);
19002+ const tensor = new torch.Tensor();
19003+ tensor.__setstate__([storage, 0, [num_elements], [1]]);
19004+ tensor.requires_grad = tensor_meta.requires_grad || false;
19005+ return tensor;
19006+ });
1887819007 this.registerFunction('torch.export.pt2_archive._package.load_pt2', (f, expected_opset_version) => {
1887919008 const exported_programs = new Map();
1888019009 for (const name of f.keys()) {
1888119010 const match = name.match(/^models\/([^/]+)\.json$/);
1888219011 if (match) {
1888319012 const [, model_name] = match;
1888419013 const serialized_exported_program = f.get(`models/${model_name}.json`);
18885- const serialized_state_dict = f.get(`data/weights/${model_name}.pt`);
18886- const serialized_constants = f.get(`data/constants/${model_name}.pt`);
18887- const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`);
19014+ const serialized_state_dict = torch.export.pt2_archive._package._load_state_dict(f, model_name);
19015+ const serialized_constants = torch.export.pt2_archive._package._load_constants(f, model_name);
19016+ const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`, 'zip');
1888819017 const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs);
1888919018 const exported_program = torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
1889019019 exported_programs.set(model_name, exported_program);
@@ -18942,7 +19071,10 @@ python.Execution = class {
1894219071 }
1894319072 });
1894419073 this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => {
18945- if (!serialized) {
19074+ if (serialized instanceof builtins.dict || serialized instanceof builtins.tuple) {
19075+ return serialized;
19076+ }
19077+ if (serialized == null || serialized.length === 0) {
1894619078 return new builtins.dict();
1894719079 }
1894819080 const artifact = torch.load(serialized);
@@ -19217,8 +19349,8 @@ python.Execution = class {
1921719349 this.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
1921819350 */
1921919351 this.example_inputs = null;
19220- if (example_inputs && example_inputs.length > 0) {
19221- torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
19352+ if (example_inputs) {
19353+ this.example_inputs = torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
1922219354 }
1922319355 this.deserialize_graph(serialized_graph_module.graph);
1922419356 const module_call_graph = null; // this.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
@@ -19265,7 +19397,7 @@ python.Execution = class {
1926519397 } else if (typ_ === 'as_tensor') {
1926619398 return this.serialized_name_to_node.get(inp.as_tensor.name);
1926719399 } else if (typ_ === 'as_scalar_type') {
19268- return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type];
19400+ return torch._export.serde.serialize.deserialize_scalar_type(inp.as_scalar_type);
1926919401 } else if (typ_ === 'as_memory_format') {
1927019402 return torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format];
1927119403 } else if (typ_ === 'as_layout') {
@@ -19499,7 +19631,7 @@ python.Execution = class {
1949919631 const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));
1950019632 const strides = tensor_meta.strides.map((val) => this.deserialize_sym_int(val));
1950119633 const device = this.deserialize_device(tensor_meta.device);
19502- const dtype = torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype];
19634+ const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
1950319635 return torch.empty_strided(sizes, strides, dtype, null, device);
1950419636 } finally {
1950519637 this.fake_tensor_mode.__exit__(null, null, null);
@@ -20350,13 +20482,13 @@ python.Execution = class {
2035020482 torch.uint16 = new torch.dtype(27, 'uint16', 2);
2035120483 torch.uint32 = new torch.dtype(28, 'uint32', 4);
2035220484 torch.uint64 = new torch.dtype(29, 'uint64', 8);
20353- torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = Object.fromEntries([
20485+ torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = new Map([
2035420486 ['uint8', 'BYTE'],
2035520487 ['int8', 'CHAR'], ['int16', 'SHORT'], ['int32', 'INT'], ['int64', 'LONG'],
2035620488 ['float16', 'HALF'], ['float32', 'FLOAT'], ['float64', 'DOUBLE'],
2035720489 ['complex32', 'COMPLEXHALF'], ['complex64', 'COMPLEXFLOAT'], ['complex128', 'COMPLEXDOUBLE'],
20358- ['bool', 'BOOL'],
20359- ['bfloat16', 'BFLOAT16']
20490+ ['bool', 'BOOL'], ['bfloat16', 'BFLOAT16'], ['uint16', 'UINT16'],
20491+ ['float8_e4m3fn', 'FLOAT8E4M3FN'], ['float8_e5m2', 'FLOAT8E5M2'], ['float8_e4m3fnuz', 'FLOAT8E4M3FNUZ'], ['float8_e5m2fnuz', 'FLOAT8E5M2FNUZ']
2036020492 ].map(([key, value]) => [torch._export.serde.schema.ScalarType[value], torch[key]]));
2036120493 torch.contiguous_format = new torch.memory_format('contiguous_format');
2036220494 torch.channels_last = new torch.memory_format('channels_last');
0 commit comments