Shorkie, a sequence-to-expression model pretrained with masked language modelling, came out sometime last year (2025). I have always wanted to do a proper port of a model from one framework to another, so here we go!

The first observation is that the real challenge is not the difference between the frameworks, but that Shorkie was built with Calico’s libraries baskerville and westminster, which automate much of the model “assembly” and training.

The official training script, make_model.sh, calls westminster_train_folds.py from the westminster repository. This in turn calls hound_train.py (https://github.com/calico/baskerville-yeast/blob/main/src/baskerville/scripts/hound_train.py) from https://github.com/calico/baskerville-yeast. There, on line 181, we find seqnn_model = seqnn.SeqNN(params_model), which builds the model from the "model" section of the config file (https://storage.googleapis.com/seqnn-share/shorkie/params.json). For Shorkie, the model params are:

"model": {
        "seq_length": 16384,
        "augment_rc": false,
        "augment_shift": 0,
        "activation": "gelu",
        "norm_type": "batch",
        "bn_momentum": 0.9,
        "kernel_initializer": "lecun_normal",
        "l2_scale": 1.0e-6,
        "trunk": [
            {
                "name": "conv_dna",
                "filters": 96,
                "kernel_size": 11,
                "norm_type": null,
                "activation": "linear"
            },
            {
                "name": "res_tower",
                "filters_init": 96,
                "filters_end": 384,
                "divisible_by": 32,
                "kernel_size": 5,
                "num_convs": 2,
                "dropout": 0.05,
                "pool_size": 2,
                "repeat": 7
            },
            {
                "name": "transformer_tower",
                "key_size": 64,
                "heads": 4,
                "num_position_features": 32,
                "dropout": 0.2,
                "mha_l2_scale": 1.0e-8,
                "l2_scale": 1.0e-8,
                "kernel_initializer": "he_normal",
                "repeat": 8
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "Cropping1D",
                "cropping": 64
            }
        ],
        "head": {
            "name": "final",
            "units": 5215,
            "activation": "softplus"
        }
    }
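
The res_tower entry only gives a start width (filters_init), an end width (filters_end), a repeat count and divisible_by, so the per-repeat widths have to be derived. Here is a minimal sketch of one plausible derivation. It assumes a geometric schedule (a constant multiplier per repeat) with each width rounded to the nearest multiple of divisible_by; res_tower_filters is a hypothetical helper, not a baskerville function, and the real rule should be checked against the baskerville source before porting.

```python
import math


def res_tower_filters(filters_init, filters_end, repeat, divisible_by=1):
    """Hypothetical helper: per-repeat filter widths for a res_tower entry.

    Assumption (not confirmed by the config alone): widths grow
    geometrically from filters_init to filters_end and are rounded
    to the nearest multiple of divisible_by.
    """
    # Constant multiplier that reaches filters_end after (repeat - 1) steps.
    mult = math.exp(math.log(filters_end / filters_init) / (repeat - 1))

    widths = []
    current = float(filters_init)
    for _ in range(repeat):
        widths.append(int(round(current / divisible_by)) * divisible_by)
        current *= mult
    return widths


# Values taken from the res_tower entry of the Shorkie config.
widths = res_tower_filters(96, 384, repeat=7, divisible_by=32)
print(widths)
```

Under these assumptions the seven repeats would use 96, 128, 160, 192, 256, 320 and 384 filters. These are the channel counts a port would have to reproduce exactly for the pretrained weights to fit.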

Following the import (line 27, from baskerville import seqnn), we get to SeqNN in src/baskerville/seqnn.py of baskerville-yeast. Its __init__ method sets defaults, copies every entry of params onto the instance as an attribute, and then calls build_model():

class SeqNN:
    """Sequence neural network model.

    Args:
      params (dict): Model specification and parameters.
    """

    def __init__(self, params: dict):
        self.set_defaults()
        for key, value in params.items():
            self.__setattr__(key, value)
        self.build_model()
        self.ensemble = None
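
To make the constructor pattern concrete, here is a small stand-alone sketch of the same idea: defaults first, then every config key overwrites or adds an attribute, then the build step reads those attributes. TinyConfigModel and its default values are hypothetical and only mirror the structure of the snippet above, not the real SeqNN defaults.

```python
class TinyConfigModel:
    """Hypothetical stand-in that mimics the SeqNN constructor pattern."""

    def __init__(self, params: dict):
        self.set_defaults()
        # Every top-level config key becomes an instance attribute.
        for key, value in params.items():
            setattr(self, key, value)
        self.build_model()

    def set_defaults(self):
        # Illustrative defaults only; the real SeqNN defaults may differ.
        self.augment_rc = False
        self.verbose = False
        self.trunk = []

    def build_model(self):
        # The build step only ever looks at attributes, never at the dict.
        self.block_names = [block["name"] for block in self.trunk]


model = TinyConfigModel(
    {
        "seq_length": 16384,
        "trunk": [{"name": "conv_dna"}, {"name": "res_tower"}],
    }
)
print(model.seq_length, model.verbose, model.block_names)
```

This is why build_model() can refer to self.seq_length and self.trunk without receiving any arguments, and why it finds the heads by scanning vars(self) for names starting with "head".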

The build_model function has four parts: inputs, convolution blocks, heads and “compilation”.

def build_model(self, save_reprs: bool = True):
        """Build the model."""

        ###################################################
        # inputs
        sequence = tf.keras.Input(shape=(self.seq_length, self.num_features), name="sequence")
        current = sequence

        # augmentation
        if self.augment_rc:
            current, reverse_bool = layers.StochasticReverseComplement()(current)
        if self.augment_shift != [0]:
            current = layers.StochasticShift(self.augment_shift)(current)
        self.preds_triu = False

        ###################################################
        # build convolution blocks
        self.reprs = []
        for bi, block_params in enumerate(self.trunk):
            current = self.build_block(current, block_params)
            if save_reprs:
                self.reprs.append(current)

        # final activation
        current = layers.activate(current, self.activation)

        # make model trunk
        trunk_output = current
        self.model_trunk = tf.keras.Model(inputs=sequence, outputs=trunk_output)

        ###################################################
        # heads
        head_keys = natsorted([v for v in vars(self) if v.startswith("head")])
        self.heads = [getattr(self, hk) for hk in head_keys]

        self.head_output = []
        for hi, head in enumerate(self.heads):
            if not isinstance(head, list):
                head = [head]

            # reset to trunk output
            current = trunk_output

            # build blocks
            for bi, block_params in enumerate(head):
                current = self.build_block(current, block_params)

            if hi < len(self.strand_pair):
                strand_pair = self.strand_pair[hi]
            else:
                strand_pair = None

            # transform back from reverse complement
            if self.augment_rc:
                if self.preds_triu:
                    current = layers.SwitchReverseTriu(self.diagonal_offset)(
                        [current, reverse_bool]
                    )
                else:
                    current = layers.SwitchReverse(strand_pair)([current, reverse_bool])

            # save head output
            self.head_output.append(current)

        ###################################################
        # compile model(s)
        self.models = []
        for ho in self.head_output:
            self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
        self.model = self.models[0]
        if self.verbose:
            print(self.model.summary())

        # track pooling/striding and cropping
        self.track_sequence(sequence)
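
The last step tracks how pooling and cropping change the sequence length. For a port it is useful to know the expected output length up front, so here is a rough sketch of that bookkeeping for the Shorkie trunk. It rests on three assumptions that are not stated in the config: each res_tower repeat pools once by pool_size, each unet_conv upsamples by a factor of 2, and all other blocks keep the length unchanged. Cropping1D with an integer argument trims that many positions from both ends, as in Keras. track_length is a hypothetical helper, not part of baskerville.

```python
def track_length(seq_length, trunk):
    """Hypothetical helper: follow the sequence length through the trunk."""
    length = seq_length
    for block in trunk:
        name = block["name"]
        if name == "res_tower":
            # Assumption: one pooling step of size pool_size per repeat.
            length //= block.get("pool_size", 1) ** block.get("repeat", 1)
        elif name == "unet_conv":
            # Assumption: each unet_conv block upsamples by a factor of 2.
            length *= 2
        elif name == "Cropping1D":
            # An integer cropping removes that many positions on each side.
            length -= 2 * block["cropping"]
        # Assumption: all other blocks preserve the length.
    return length


# Only the keys relevant for the length are kept from the Shorkie config.
shorkie_trunk = [
    {"name": "conv_dna"},
    {"name": "res_tower", "pool_size": 2, "repeat": 7},
    {"name": "transformer_tower"},
    {"name": "unet_conv"},
    {"name": "unet_conv"},
    {"name": "unet_conv"},
    {"name": "Cropping1D", "cropping": 64},
]

out_length = track_length(16384, shorkie_trunk)
print(out_length)
```

If these assumptions hold, the 16384 input positions shrink to 128 after the residual tower, grow back to 1024 after the three U-Net blocks, and end at 896 after cropping. Comparing this number with the output shape of the original model is a cheap first sanity check for the port.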

We can skip straight to building the trunk, which uses this part of the config:

"trunk": [
            {
                "name": "conv_dna",
                "filters": 96,
                "kernel_size": 11,
                "norm_type": null,
                "activation": "linear"
            },
            {
                "name": "res_tower",
                "filters_init": 96,
                "filters_end": 384,
                "divisible_by": 32,
                "kernel_size": 5,
                "num_convs": 2,
                "dropout": 0.05,
                "pool_size": 2,
                "repeat": 7
            },
            {
                "name": "transformer_tower",
                "key_size": 64,
                "heads": 4,
                "num_position_features": 32,
                "dropout": 0.2,
                "mha_l2_scale": 1.0e-8,
                "l2_scale": 1.0e-8,
                "kernel_initializer": "he_normal",
                "repeat": 8
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "unet_conv",
                "kernel_size": 3,
                "upsample_conv": true
            },
            {
                "name": "Cropping1D",
                "cropping": 64
            }
        ],

We loop over this list, feeding the output of each block into the next:

for bi, block_params in enumerate(self.trunk):
    current = self.build_block(current, block_params)

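In PyTorch terms, this loop is roughly an nn.Sequential over the corresponding blocks, which we still have to implement. One caveat: build_model() also stores every block output in self.reprs, which suggests that later blocks (presumably the unet_conv ones) reach back to earlier representations, something a plain nn.Sequential cannot express.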
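
Before implementing the real blocks, the config-driven assembly itself can be sketched in plain Python. The idea is a registry that maps each "name" in the config to a builder function and passes the remaining keys as keyword arguments. Everything here (BLOCK_BUILDERS, register, build_trunk) is a hypothetical name for illustration, and the builders return strings as placeholders; in the actual port each builder would return a PyTorch module instead. How baskerville's build_block resolves names internally is not shown above, so this is only one way to get the same behaviour.

```python
from typing import Callable, Dict, List

# Hypothetical registry: block name from the config -> builder function.
BLOCK_BUILDERS: Dict[str, Callable[..., str]] = {}


def register(name: str):
    """Decorator that registers a builder under a config block name."""
    def decorator(fn):
        BLOCK_BUILDERS[name] = fn
        return fn
    return decorator


@register("conv_dna")
def conv_dna(filters, kernel_size, **kwargs):
    # Placeholder: a real port would construct and return a module here.
    return f"conv_dna(filters={filters}, kernel_size={kernel_size})"


@register("Cropping1D")
def cropping_1d(cropping, **kwargs):
    return f"crop({cropping})"


def build_trunk(trunk: List[dict]) -> List[str]:
    """Turn a list of block configs into a list of built blocks."""
    modules = []
    for block_params in trunk:
        params = dict(block_params)  # copy so the config is not mutated
        name = params.pop("name")
        if name not in BLOCK_BUILDERS:
            raise KeyError(f"no builder registered for block '{name}'")
        modules.append(BLOCK_BUILDERS[name](**params))
    return modules


config = [
    {"name": "conv_dna", "filters": 96, "kernel_size": 11,
     "norm_type": None, "activation": "linear"},
    {"name": "Cropping1D", "cropping": 64},
]
modules = build_trunk(config)
print(modules)
```

Raising on unknown names is deliberate: while porting, a missing block should fail loudly rather than be skipped silently, because a skipped block would only show up later as a shape or accuracy mismatch.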