001/*
002 * SPDX-FileCopyrightText: none
003 * SPDX-License-Identifier: CC0-1.0
004 */
005
006package gov.nist.secauto.oscal.lib.model.util;
007
008import gov.nist.secauto.metaschema.core.metapath.DynamicContext;
009import gov.nist.secauto.metaschema.core.metapath.StaticContext;
010import gov.nist.secauto.metaschema.core.metapath.item.ISequence;
011import gov.nist.secauto.metaschema.core.metapath.item.node.AbstractRecursionPreventingNodeItemVisitor;
012import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyInstanceGroupedNodeItem;
013import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyNodeItem;
014import gov.nist.secauto.metaschema.core.metapath.item.node.IDefinitionNodeItem;
015import gov.nist.secauto.metaschema.core.metapath.item.node.IFieldNodeItem;
016import gov.nist.secauto.metaschema.core.metapath.item.node.IFlagNodeItem;
017import gov.nist.secauto.metaschema.core.metapath.item.node.IModuleNodeItem;
018import gov.nist.secauto.metaschema.core.metapath.item.node.INodeItemFactory;
019import gov.nist.secauto.metaschema.core.model.IModule;
020import gov.nist.secauto.metaschema.core.model.constraint.IAllowedValuesConstraint;
021import gov.nist.secauto.metaschema.core.model.constraint.ILet;
022
023import java.util.Collection;
024import java.util.LinkedHashMap;
025import java.util.LinkedList;
026import java.util.List;
027import java.util.Map;
028
029import edu.umd.cs.findbugs.annotations.NonNull;
030
031public class AllowedValueCollectingNodeItemVisitor
032    extends AbstractRecursionPreventingNodeItemVisitor<DynamicContext, Void> {
033
034  private final Map<IDefinitionNodeItem<?, ?>, NodeItemRecord> nodeItemAnalysis = new LinkedHashMap<>();
035
036  public Collection<NodeItemRecord> getAllowedValueLocations() {
037    return nodeItemAnalysis.values();
038  }
039
040  public void visit(@NonNull IModule module) {
041    DynamicContext context = new DynamicContext(
042        StaticContext.builder()
043            .defaultModelNamespace(module.getXmlNamespace())
044            .build());
045    context.disablePredicateEvaluation();
046
047    visit(INodeItemFactory.instance().newModuleNodeItem(module), context);
048  }
049
050  public void visit(@NonNull IModuleNodeItem module, @NonNull DynamicContext context) {
051
052    visitMetaschema(module, context);
053  }
054
055  private void handleAllowedValuesAtLocation(
056      @NonNull IDefinitionNodeItem<?, ?> itemLocation,
057      @NonNull DynamicContext context) {
058    itemLocation.getDefinition().getAllowedValuesConstraints().stream()
059        .forEachOrdered(allowedValues -> {
060          ISequence<?> result = allowedValues.getTarget().evaluate(itemLocation, context);
061          result.stream().forEachOrdered(target -> {
062            assert target != null;
063            handleAllowedValues(allowedValues, itemLocation, (IDefinitionNodeItem<?, ?>) target);
064          });
065        });
066  }
067
068  private void handleAllowedValues(
069      @NonNull IAllowedValuesConstraint allowedValues,
070      @NonNull IDefinitionNodeItem<?, ?> location,
071      @NonNull IDefinitionNodeItem<?, ?> target) {
072    NodeItemRecord itemRecord = nodeItemAnalysis.get(target);
073    if (itemRecord == null) {
074      itemRecord = new NodeItemRecord(target);
075      nodeItemAnalysis.put(target, itemRecord);
076    }
077
078    AllowedValuesRecord allowedValuesRecord = new AllowedValuesRecord(allowedValues, location, target);
079    itemRecord.addAllowedValues(allowedValuesRecord);
080  }
081
082  @Override
083  public Void visitFlag(IFlagNodeItem item, DynamicContext context) {
084    assert context != null;
085    DynamicContext subContext = handleLetStatements(item, context);
086    handleAllowedValuesAtLocation(item, subContext);
087    return super.visitFlag(item, subContext);
088  }
089
090  @Override
091  public Void visitField(IFieldNodeItem item, DynamicContext context) {
092    assert context != null;
093    DynamicContext subContext = handleLetStatements(item, context);
094    handleAllowedValuesAtLocation(item, subContext);
095    return super.visitField(item, subContext);
096  }
097
098  @Override
099  public Void visitAssembly(IAssemblyNodeItem item, DynamicContext context) {
100    assert context != null;
101    DynamicContext subContext = handleLetStatements(item, context);
102    handleAllowedValuesAtLocation(item, subContext);
103    return super.visitAssembly(item, subContext);
104  }
105
106  private DynamicContext handleLetStatements(IDefinitionNodeItem<?, ?> item, DynamicContext context) {
107    assert context != null;
108    DynamicContext subContext = context;
109    for (ILet let : item.getDefinition().getLetExpressions().values()) {
110      ISequence<?> result = let.getValueExpression().evaluate(item,
111          subContext).reusable();
112      subContext = subContext.bindVariableValue(let.getName(), result);
113    }
114    return subContext;
115  }
116
117  @Override
118  public Void visitAssembly(IAssemblyInstanceGroupedNodeItem item, DynamicContext context) {
119    return visitAssembly((IAssemblyNodeItem) item, context);
120  }
121
122  @Override
123  protected Void defaultResult() {
124    return null;
125  }
126
127  public static final class NodeItemRecord {
128    @NonNull
129    private final IDefinitionNodeItem<?, ?> item;
130    @NonNull
131    private final List<AllowedValuesRecord> allowedValues = new LinkedList<>();
132
133    private NodeItemRecord(@NonNull IDefinitionNodeItem<?, ?> item) {
134      this.item = item;
135    }
136
137    @NonNull
138    public IDefinitionNodeItem<?, ?> getItem() {
139      return item;
140    }
141
142    @NonNull
143    public List<AllowedValuesRecord> getAllowedValues() {
144      return allowedValues;
145    }
146
147    public void addAllowedValues(@NonNull AllowedValuesRecord record) {
148      this.allowedValues.add(record);
149    }
150  }
151
152  public static final class AllowedValuesRecord {
153    @NonNull
154    private final IAllowedValuesConstraint allowedValues;
155    @NonNull
156    private final IDefinitionNodeItem<?, ?> location;
157    @NonNull
158    private final IDefinitionNodeItem<?, ?> target;
159
160    public AllowedValuesRecord(
161        @NonNull IAllowedValuesConstraint allowedValues,
162        @NonNull IDefinitionNodeItem<?, ?> location,
163        @NonNull IDefinitionNodeItem<?, ?> target) {
164      this.allowedValues = allowedValues;
165      this.location = location;
166      this.target = target;
167    }
168
169    @NonNull
170    public IAllowedValuesConstraint getAllowedValues() {
171      return allowedValues;
172    }
173
174    @NonNull
175    public IDefinitionNodeItem<?, ?> getLocation() {
176      return location;
177    }
178
179    @NonNull
180    public IDefinitionNodeItem<?, ?> getTarget() {
181      return target;
182    }
183  }
184}