@@ -13,6 +13,11 @@ def draw(self, *args, **kwargs):
1313 def __call__ (self , * args , ** kwargs ):
1414 self .draw (* args , ** kwargs )
1515
16+ def set_output_mode (self , mode : str ):
17+ """Set notebook or script mode - not implemented yet"""
18+ ...
19+
20+
1621
1722class LossSubplot (BaseSubplot ):
1823 """To rewrire, this one now won't work"""
@@ -59,6 +64,7 @@ def draw(self, logs):
5964 plt .title (self .title )
6065 plt .xlabel ('epoch' )
6166 plt .legend (loc = 'center right' )
67+ plt .show ()
6268
6369
6470class Plot1D (BaseSubplot ):
@@ -77,6 +83,7 @@ def draw(self, *args, **kwargs):
7783 plt .plot (self .X , self .predict (self .model , self .X ), '-' , label = "Model" )
7884 plt .title ("Prediction" )
7985 plt .legend (loc = 'lower right' )
86+ plt .show ()
8087
8188
8289class Plot2d (BaseSubplot ):
@@ -119,3 +126,4 @@ def send(self, logger):
119126 plt .scatter (self .X [:, 0 ], self .X [:, 1 ], c = self .Y , cmap = self .cm_points )
120127 if self .X_test is not None :
121128 plt .scatter (self .X_test [:, 0 ], self .X_test [:, 1 ], c = self .Y_test , cmap = self .cm_points , alpha = 0.3 )
129+ plt .show ()
0 commit comments